Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
98cb7b2c
Unverified
Commit
98cb7b2c
authored
Apr 23, 2019
by
Thomas Wolf
Committed by
GitHub
Apr 23, 2019
Browse files
Merge pull request #445 from lukovnikov/master
Learning rate schedules improvement + extension
parents
68a889ee
69850b40
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
226 additions
and
103 deletions
+226
-103
pytorch_pretrained_bert/optimization.py
pytorch_pretrained_bert/optimization.py
+165
-50
pytorch_pretrained_bert/optimization_openai.py
pytorch_pretrained_bert/optimization_openai.py
+19
-53
tests/optimization_test.py
tests/optimization_test.py
+42
-0
No files found.
pytorch_pretrained_bert/optimization.py
View file @
98cb7b2c
...
...
@@ -20,33 +20,157 @@ from torch.optim import Optimizer
from
torch.optim.optimizer
import
required
from
torch.nn.utils
import
clip_grad_norm_
import
logging
import
abc
import
sys
logger
=
logging
.
getLogger
(
__name__
)
def
warmup_cosine
(
x
,
warmup
=
0.002
):
if
x
<
warmup
:
return
x
/
warmup
x_
=
(
x
-
warmup
)
/
(
1
-
warmup
)
# progress after warmup -
return
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
x_
))
def
warmup_constant
(
x
,
warmup
=
0.002
):
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
Learning rate is 1. afterwards. """
if
x
<
warmup
:
return
x
/
warmup
return
1.0
def
warmup_linear
(
x
,
warmup
=
0.002
):
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
After `t_total`-th training step, learning rate is zero. """
if
x
<
warmup
:
return
x
/
warmup
return
max
((
x
-
1.
)
/
(
warmup
-
1.
),
0
)
if
sys
.
version_info
>=
(
3
,
4
):
ABC
=
abc
.
ABC
else
:
ABC
=
abc
.
ABCMeta
(
'ABC'
,
(),
{})
class
_LRSchedule
(
ABC
):
""" Parent of all LRSchedules here. """
warn_t_total
=
False
# is set to True for schedules where progressing beyond t_total steps doesn't make sense
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
**
kw
):
"""
:param warmup: what fraction of t_total steps will be used for linear warmup
:param t_total: how many training steps (updates) are planned
:param kw:
"""
super
(
_LRSchedule
,
self
).
__init__
(
**
kw
)
if
t_total
<
0
:
logger
.
warning
(
"t_total value of {} results in schedule not being applied"
.
format
(
t_total
))
if
not
0.0
<=
warmup
<
1.0
and
not
warmup
==
-
1
:
raise
ValueError
(
"Invalid warmup: {} - should be in [0.0, 1.0[ or -1"
.
format
(
warmup
))
warmup
=
max
(
warmup
,
0.
)
self
.
warmup
,
self
.
t_total
=
float
(
warmup
),
float
(
t_total
)
self
.
warned_for_t_total_at_progress
=
-
1
def
get_lr
(
self
,
step
,
nowarn
=
False
):
"""
:param step: which of t_total steps we're on
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
:return: learning rate multiplier for current update
"""
if
self
.
t_total
<
0
:
return
1.
progress
=
float
(
step
)
/
self
.
t_total
ret
=
self
.
get_lr_
(
progress
)
# warning for exceeding t_total (only active with warmup_linear
if
not
nowarn
and
self
.
warn_t_total
and
progress
>
1.
and
progress
>
self
.
warned_for_t_total_at_progress
:
logger
.
warning
(
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
.
format
(
ret
,
self
.
__class__
.
__name__
))
self
.
warned_for_t_total_at_progress
=
progress
# end warning
return
ret
@
abc
.
abstractmethod
def
get_lr_
(
self
,
progress
):
"""
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
:return: learning rate multiplier for current update
"""
return
1.
class
ConstantLR
(
_LRSchedule
):
def
get_lr_
(
self
,
progress
):
return
1.
class
WarmupCosineSchedule
(
_LRSchedule
):
"""
Cosine learning rate schedule with linear warmup. Cosine after warmup is without restarts.
"""
warn_t_total
=
True
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
.
5
,
**
kw
):
"""
:param warmup: see LRSchedule
:param t_total: see LRSchedule
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
:param kw:
"""
super
(
WarmupCosineSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
**
kw
)
self
.
cycles
=
cycles
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
else
:
progress
=
(
progress
-
self
.
warmup
)
/
(
1
-
self
.
warmup
)
# progress after warmup
return
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
self
.
cycles
*
2
*
progress
))
class
WarmupCosineWithHardRestartsSchedule
(
WarmupCosineSchedule
):
"""
Cosine learning rate schedule with linear warmup and hard restarts (if cycles > 1).
"""
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
1.
,
**
kw
):
super
(
WarmupCosineWithHardRestartsSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
cycles
=
cycles
,
**
kw
)
assert
(
cycles
>=
1.
)
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
else
:
progress
=
(
progress
-
self
.
warmup
)
/
(
1
-
self
.
warmup
)
# progress after warmup
ret
=
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
((
self
.
cycles
*
progress
)
%
1
)))
return
ret
class
WarmupCosineWithWarmupRestartsSchedule
(
WarmupCosineWithHardRestartsSchedule
):
"""
Cosine learning rate schedule with linear warmups and linear warmup restarts.
The same warmup rate is used for warmup restarts as for initial warmup.
The total effective fraction of warmup steps over all cycles is warmup * cycles!
"""
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
1.
,
**
kw
):
assert
(
warmup
*
cycles
<
1.
)
warmup
=
warmup
*
cycles
if
warmup
>=
0
else
warmup
super
(
WarmupCosineWithWarmupRestartsSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
cycles
=
cycles
,
**
kw
)
def
get_lr_
(
self
,
progress
):
progress
=
progress
*
self
.
cycles
%
1.
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
else
:
progress
=
(
progress
-
self
.
warmup
)
/
(
1
-
self
.
warmup
)
# progress after warmup
ret
=
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
progress
))
return
ret
class
WarmupConstantSchedule
(
_LRSchedule
):
"""
Applies linear warmup. After warmup always returns 1..
"""
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
return
1.
class
WarmupLinearSchedule
(
_LRSchedule
):
"""
Linear warmup. Linear decay after warmup.
"""
warn_t_total
=
True
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
return
max
((
progress
-
1.
)
/
(
self
.
warmup
-
1.
),
0.
)
SCHEDULES
=
{
'warmup_cosine'
:
warmup_cosine
,
'warmup_constant'
:
warmup_constant
,
'warmup_linear'
:
warmup_linear
,
None
:
ConstantLR
,
"none"
:
ConstantLR
,
"warmup_cosine"
:
WarmupCosineSchedule
,
"warmup_constant"
:
WarmupConstantSchedule
,
"warmup_linear"
:
WarmupLinearSchedule
}
...
...
@@ -56,8 +180,10 @@ class BertAdam(Optimizer):
lr: learning rate
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
t_total: total number of training steps for the learning
rate schedule, -1 means constant learning rate. Default: -1
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
schedule: schedule to use for the warmup (see above).
Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
Default: 'warmup_linear'
b1: Adams b1. Default: 0.9
b2: Adams b2. Default: 0.999
e: Adams epsilon. Default: 1e-6
...
...
@@ -65,21 +191,26 @@ class BertAdam(Optimizer):
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
"""
def
__init__
(
self
,
params
,
lr
=
required
,
warmup
=-
1
,
t_total
=-
1
,
schedule
=
'warmup_linear'
,
b1
=
0.9
,
b2
=
0.999
,
e
=
1e-6
,
weight_decay
=
0.01
,
max_grad_norm
=
1.0
):
b1
=
0.9
,
b2
=
0.999
,
e
=
1e-6
,
weight_decay
=
0.01
,
max_grad_norm
=
1.0
,
**
kwargs
):
if
lr
is
not
required
and
lr
<
0.0
:
raise
ValueError
(
"Invalid learning rate: {} - should be >= 0.0"
.
format
(
lr
))
if
schedule
not
in
SCHEDULES
:
if
not
isinstance
(
schedule
,
_LRSchedule
)
and
schedule
not
in
SCHEDULES
:
raise
ValueError
(
"Invalid schedule parameter: {}"
.
format
(
schedule
))
if
not
0.0
<=
warmup
<
1.0
and
not
warmup
==
-
1
:
raise
ValueError
(
"Invalid warmup: {} - should be in [0.0, 1.0[ or -1"
.
format
(
warmup
))
if
not
0.0
<=
b1
<
1.0
:
raise
ValueError
(
"Invalid b1 parameter: {} - should be in [0.0, 1.0["
.
format
(
b1
))
if
not
0.0
<=
b2
<
1.0
:
raise
ValueError
(
"Invalid b2 parameter: {} - should be in [0.0, 1.0["
.
format
(
b2
))
if
not
e
>=
0.0
:
raise
ValueError
(
"Invalid epsilon value: {} - should be >= 0.0"
.
format
(
e
))
defaults
=
dict
(
lr
=
lr
,
schedule
=
schedule
,
warmup
=
warmup
,
t_total
=
t_total
,
# initialize schedule object
if
not
isinstance
(
schedule
,
_LRSchedule
):
schedule_type
=
SCHEDULES
[
schedule
]
schedule
=
schedule_type
(
warmup
=
warmup
,
t_total
=
t_total
)
else
:
if
warmup
!=
-
1
or
t_total
!=
-
1
:
logger
.
warning
(
"Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
"Please specify custom warmup and t_total in LRSchedule object."
)
defaults
=
dict
(
lr
=
lr
,
schedule
=
schedule
,
b1
=
b1
,
b2
=
b2
,
e
=
e
,
weight_decay
=
weight_decay
,
max_grad_norm
=
max_grad_norm
)
super
(
BertAdam
,
self
).
__init__
(
params
,
defaults
)
...
...
@@ -91,11 +222,8 @@ class BertAdam(Optimizer):
state
=
self
.
state
[
p
]
if
len
(
state
)
==
0
:
return
[
0
]
if
group
[
't_total'
]
!=
-
1
:
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
state
[
'step'
]
/
group
[
't_total'
],
group
[
'warmup'
])
else
:
lr_scheduled
=
group
[
'lr'
]
lr_scheduled
*=
group
[
'schedule'
].
get_lr
(
state
[
'step'
])
lr
.
append
(
lr_scheduled
)
return
lr
...
...
@@ -110,8 +238,6 @@ class BertAdam(Optimizer):
if
closure
is
not
None
:
loss
=
closure
()
warned_for_t_total
=
False
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
...
...
@@ -153,19 +279,8 @@ class BertAdam(Optimizer):
if
group
[
'weight_decay'
]
>
0.0
:
update
+=
group
[
'weight_decay'
]
*
p
.
data
if
group
[
't_total'
]
!=
-
1
:
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
progress
=
state
[
'step'
]
/
group
[
't_total'
]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
progress
,
group
[
'warmup'
])
# warning for exceeding t_total (only active with warmup_linear
if
group
[
'schedule'
]
==
"warmup_linear"
and
progress
>
1.
and
not
warned_for_t_total
:
logger
.
warning
(
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
"Please set 't_total' of {} correctly."
.
format
(
group
[
'schedule'
],
lr_scheduled
,
self
.
__class__
.
__name__
))
warned_for_t_total
=
True
# end warning
else
:
lr_scheduled
=
group
[
'lr'
]
lr_scheduled
*=
group
[
'schedule'
].
get_lr
(
state
[
'step'
])
update_with_lr
=
lr_scheduled
*
update
p
.
data
.
add_
(
-
update_with_lr
)
...
...
pytorch_pretrained_bert/optimization_openai.py
View file @
98cb7b2c
...
...
@@ -20,35 +20,11 @@ from torch.optim import Optimizer
from
torch.optim.optimizer
import
required
from
torch.nn.utils
import
clip_grad_norm_
import
logging
from
.optimization
import
SCHEDULES
,
_LRSchedule
,
WarmupCosineWithWarmupRestartsSchedule
,
\
WarmupCosineWithHardRestartsSchedule
,
WarmupCosineSchedule
,
WarmupLinearSchedule
,
WarmupConstantSchedule
logger
=
logging
.
getLogger
(
__name__
)
def
warmup_cosine
(
x
,
warmup
=
0.002
):
if
x
<
warmup
:
return
x
/
warmup
x_
=
(
x
-
warmup
)
/
(
1
-
warmup
)
# progress after warmup
return
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
x_
))
def
warmup_constant
(
x
,
warmup
=
0.002
):
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
Learning rate is 1. afterwards. """
if
x
<
warmup
:
return
x
/
warmup
return
1.0
def
warmup_linear
(
x
,
warmup
=
0.002
):
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
After `t_total`-th training step, learning rate is zero. """
if
x
<
warmup
:
return
x
/
warmup
return
max
((
x
-
1.
)
/
(
warmup
-
1.
),
0
)
SCHEDULES
=
{
'warmup_cosine'
:
warmup_cosine
,
'warmup_constant'
:
warmup_constant
,
'warmup_linear'
:
warmup_linear
,
}
class
OpenAIAdam
(
Optimizer
):
"""Implements Open AI version of Adam algorithm with weight decay fix.
...
...
@@ -58,17 +34,23 @@ class OpenAIAdam(Optimizer):
vector_l2
=
False
,
max_grad_norm
=-
1
,
**
kwargs
):
if
lr
is
not
required
and
lr
<
0.0
:
raise
ValueError
(
"Invalid learning rate: {} - should be >= 0.0"
.
format
(
lr
))
if
schedule
not
in
SCHEDULES
:
if
not
isinstance
(
schedule
,
_LRSchedule
)
and
schedule
not
in
SCHEDULES
:
raise
ValueError
(
"Invalid schedule parameter: {}"
.
format
(
schedule
))
if
not
0.0
<=
warmup
<
1.0
and
not
warmup
==
-
1
:
raise
ValueError
(
"Invalid warmup: {} - should be in [0.0, 1.0[ or -1"
.
format
(
warmup
))
if
not
0.0
<=
b1
<
1.0
:
raise
ValueError
(
"Invalid b1 parameter: {}"
.
format
(
b1
))
raise
ValueError
(
"Invalid b1 parameter: {}
- should be in [0.0, 1.0[
"
.
format
(
b1
))
if
not
0.0
<=
b2
<
1.0
:
raise
ValueError
(
"Invalid b2 parameter: {}"
.
format
(
b2
))
raise
ValueError
(
"Invalid b2 parameter: {}
- should be in [0.0, 1.0[
"
.
format
(
b2
))
if
not
e
>=
0.0
:
raise
ValueError
(
"Invalid epsilon value: {}"
.
format
(
e
))
defaults
=
dict
(
lr
=
lr
,
schedule
=
schedule
,
warmup
=
warmup
,
t_total
=
t_total
,
raise
ValueError
(
"Invalid epsilon value: {} - should be >= 0.0"
.
format
(
e
))
# initialize schedule object
if
not
isinstance
(
schedule
,
_LRSchedule
):
schedule_type
=
SCHEDULES
[
schedule
]
schedule
=
schedule_type
(
warmup
=
warmup
,
t_total
=
t_total
)
else
:
if
warmup
!=
-
1
or
t_total
!=
-
1
:
logger
.
warning
(
"Non-default warmup and t_total are ineffective when LRSchedule object is provided. "
"Please specify custom warmup and t_total in LRSchedule object."
)
defaults
=
dict
(
lr
=
lr
,
schedule
=
schedule
,
b1
=
b1
,
b2
=
b2
,
e
=
e
,
weight_decay
=
weight_decay
,
vector_l2
=
vector_l2
,
max_grad_norm
=
max_grad_norm
)
super
(
OpenAIAdam
,
self
).
__init__
(
params
,
defaults
)
...
...
@@ -80,11 +62,8 @@ class OpenAIAdam(Optimizer):
state
=
self
.
state
[
p
]
if
len
(
state
)
==
0
:
return
[
0
]
if
group
[
't_total'
]
!=
-
1
:
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
state
[
'step'
]
/
group
[
't_total'
],
group
[
'warmup'
])
else
:
lr_scheduled
=
group
[
'lr'
]
lr_scheduled
*=
group
[
'schedule'
].
get_lr
(
state
[
'step'
])
lr
.
append
(
lr_scheduled
)
return
lr
...
...
@@ -99,8 +78,6 @@ class OpenAIAdam(Optimizer):
if
closure
is
not
None
:
loss
=
closure
()
warned_for_t_total
=
False
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
...
...
@@ -136,19 +113,8 @@ class OpenAIAdam(Optimizer):
bias_correction1
=
1
-
beta1
**
state
[
'step'
]
bias_correction2
=
1
-
beta2
**
state
[
'step'
]
if
group
[
't_total'
]
!=
-
1
:
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
progress
=
state
[
'step'
]
/
group
[
't_total'
]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
progress
,
group
[
'warmup'
])
# warning for exceeding t_total (only active with warmup_linear
if
group
[
'schedule'
]
==
"warmup_linear"
and
progress
>
1.
and
not
warned_for_t_total
:
logger
.
warning
(
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
"Please set 't_total' of {} correctly."
.
format
(
group
[
'schedule'
],
lr_scheduled
,
self
.
__class__
.
__name__
))
warned_for_t_total
=
True
# end warning
else
:
lr_scheduled
=
group
[
'lr'
]
lr_scheduled
*=
group
[
'schedule'
].
get_lr
(
state
[
'step'
])
step_size
=
lr_scheduled
*
math
.
sqrt
(
bias_correction2
)
/
bias_correction1
...
...
tests/optimization_test.py
View file @
98cb7b2c
...
...
@@ -21,6 +21,10 @@ import unittest
import
torch
from
pytorch_pretrained_bert
import
BertAdam
from
pytorch_pretrained_bert
import
OpenAIAdam
from
pytorch_pretrained_bert.optimization
import
ConstantLR
,
WarmupLinearSchedule
,
WarmupCosineWithWarmupRestartsSchedule
import
numpy
as
np
class
OptimizationTest
(
unittest
.
TestCase
):
...
...
@@ -46,5 +50,43 @@ class OptimizationTest(unittest.TestCase):
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
class
ScheduleInitTest
(
unittest
.
TestCase
):
def
test_bert_sched_init
(
self
):
m
=
torch
.
nn
.
Linear
(
50
,
50
)
optim
=
BertAdam
(
m
.
parameters
(),
lr
=
0.001
,
warmup
=
.
1
,
t_total
=
1000
,
schedule
=
None
)
self
.
assertTrue
(
isinstance
(
optim
.
param_groups
[
0
][
"schedule"
],
ConstantLR
))
optim
=
BertAdam
(
m
.
parameters
(),
lr
=
0.001
,
warmup
=
.
1
,
t_total
=
1000
,
schedule
=
"none"
)
self
.
assertTrue
(
isinstance
(
optim
.
param_groups
[
0
][
"schedule"
],
ConstantLR
))
optim
=
BertAdam
(
m
.
parameters
(),
lr
=
0.001
,
warmup
=
.
01
,
t_total
=
1000
)
self
.
assertTrue
(
isinstance
(
optim
.
param_groups
[
0
][
"schedule"
],
WarmupLinearSchedule
))
# shouldn't fail
def
test_openai_sched_init
(
self
):
m
=
torch
.
nn
.
Linear
(
50
,
50
)
optim
=
OpenAIAdam
(
m
.
parameters
(),
lr
=
0.001
,
warmup
=
.
1
,
t_total
=
1000
,
schedule
=
None
)
self
.
assertTrue
(
isinstance
(
optim
.
param_groups
[
0
][
"schedule"
],
ConstantLR
))
optim
=
OpenAIAdam
(
m
.
parameters
(),
lr
=
0.001
,
warmup
=
.
1
,
t_total
=
1000
,
schedule
=
"none"
)
self
.
assertTrue
(
isinstance
(
optim
.
param_groups
[
0
][
"schedule"
],
ConstantLR
))
optim
=
OpenAIAdam
(
m
.
parameters
(),
lr
=
0.001
,
warmup
=
.
01
,
t_total
=
1000
)
self
.
assertTrue
(
isinstance
(
optim
.
param_groups
[
0
][
"schedule"
],
WarmupLinearSchedule
))
# shouldn't fail
class
WarmupCosineWithRestartsTest
(
unittest
.
TestCase
):
def
test_it
(
self
):
m
=
WarmupCosineWithWarmupRestartsSchedule
(
warmup
=
0.05
,
t_total
=
1000.
,
cycles
=
5
)
x
=
np
.
arange
(
0
,
1000
)
y
=
[
m
.
get_lr
(
xe
)
for
xe
in
x
]
y
=
np
.
asarray
(
y
)
expected_zeros
=
y
[[
0
,
200
,
400
,
600
,
800
]]
print
(
expected_zeros
)
expected_ones
=
y
[[
50
,
250
,
450
,
650
,
850
]]
print
(
expected_ones
)
self
.
assertTrue
(
np
.
allclose
(
expected_ones
,
1
))
self
.
assertTrue
(
np
.
allclose
(
expected_zeros
,
0
))
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment