chenpangpang / transformers · Commits · 3ced9b3e

Unverified commit 3ced9b3e, authored Mar 08, 2021 by Sylvain Gugger, committed via GitHub on Mar 08, 2021.

Check layer types for Optimizer construction (#10598)

* Check layer types for Optimizer construction
* Duplicate class

Parent: 821d518e
Changes: 4 changed files with 71 additions and 3 deletions (+71 −3)

* src/transformers/trainer.py (+5 −3)
* src/transformers/trainer_pt_utils.py (+16 −0)
* tests/test_trainer.py (+26 −0)
* tests/test_trainer_utils.py (+24 −0)
src/transformers/trainer.py

@@ -80,6 +80,7 @@ from .trainer_pt_utils import (
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
     distributed_concat,
+    get_parameter_names,
     nested_concat,
     nested_detach,
     nested_numpify,

@@ -613,14 +614,15 @@ class Trainer:
             Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizer is None:
-            no_decay = ["bias", "LayerNorm.weight"]
+            decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
             optimizer_grouped_parameters = [
                 {
-                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                    "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
                     "weight_decay": self.args.weight_decay,
                 },
                 {
-                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                    "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
                     "weight_decay": 0.0,
                 },
             ]
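A minimal sketch (not part of the commit) of what the switch above buys: the old rule matched parameter-name substrings, so a LayerNorm stored under an attribute not literally named "LayerNorm" still received weight decay; the new rule asks for the layer type via get_parameter_names (added in trainer_pt_utils.py below). The toy module and its attribute names are hypothetical.

```python
import torch
from transformers.trainer_pt_utils import get_parameter_names  # helper added by this commit

class Block(torch.nn.Module):  # hypothetical toy module for illustration
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
        self.norm = torch.nn.LayerNorm(8)  # note: attribute is "norm", not "LayerNorm"

model = Block()

# Old selection, by name substring: "norm.weight" slips through and would get weight decay.
no_decay = ["bias", "LayerNorm.weight"]
old_decay = [n for n, _ in model.named_parameters() if not any(nd in n for nd in no_decay)]
print(old_decay)  # ['proj.weight', 'norm.weight']

# New selection, by layer type: anything inside an nn.LayerNorm is excluded, plus all biases.
decay_parameters = [n for n in get_parameter_names(model, [torch.nn.LayerNorm]) if "bias" not in n]
print(decay_parameters)  # ['proj.weight']
```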
src/transformers/trainer_pt_utils.py

@@ -672,3 +672,19 @@ def save_state(self):
     path = os.path.join(self.args.output_dir, "trainer_state.json")
     self.state.save_to_json(path)
+
+
+def get_parameter_names(model, forbidden_layer_types):
+    """
+    Returns the names of the model parameters that are not inside a forbidden layer.
+    """
+    result = []
+    for name, child in model.named_children():
+        result += [
+            f"{name}.{n}"
+            for n in get_parameter_names(child, forbidden_layer_types)
+            if not isinstance(child, tuple(forbidden_layer_types))
+        ]
+    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
+    result += list(model._parameters.keys())
+    return result
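A quick, hedged illustration of the recursion above: parameters inside a forbidden layer type are skipped, while an nn.Parameter defined directly on a module (like TstLayer.bias in the tests below) is always returned, and the Trainer separately filters out any name containing "bias" before applying weight decay. The Toy module here is hypothetical.

```python
import torch
from transformers.trainer_pt_utils import get_parameter_names  # defined above

class Toy(torch.nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.norm = torch.nn.LayerNorm(4)
        self.scale = torch.nn.Parameter(torch.ones(4))  # lives on Toy itself, not in any child

names = get_parameter_names(Toy(), [torch.nn.LayerNorm])
print(names)  # ['linear.weight', 'linear.bias', 'scale'] -- the LayerNorm weight/bias are excluded
```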
tests/test_trainer.py

@@ -193,6 +193,20 @@ if is_torch_available():
             loss = torch.nn.functional.mse_loss(y, labels)
             return (loss, y, y) if self.double_output else (loss, y)

+    class TstLayer(torch.nn.Module):
+        def __init__(self, hidden_size):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln1 = torch.nn.LayerNorm(hidden_size)
+            self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln2 = torch.nn.LayerNorm(hidden_size)
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
+
+        def forward(self, x):
+            h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
+            h = torch.nn.functional.relu(self.linear2(x))
+            return self.ln2(x + h + self.bias)
+
     def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
         label_names = kwargs.get("label_names", None)
         train_dataset = RegressionDataset(length=train_len, label_names=label_names)

@@ -991,6 +1005,18 @@ class TrainerIntegrationTest(unittest.TestCase):
         # perfect world: fp32_init/2 == fp16_eval
         self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)

+    def test_no_wd_param_group(self):
+        model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
+        trainer = Trainer(model=model)
+        trainer.create_optimizer_and_scheduler(10)
+        # fmt: off
+        wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
+        # fmt: on
+        wd_params = [p for n, p in model.named_parameters() if n in wd_names]
+        no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
+        self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
+        self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
+

 @require_torch
 @require_optuna
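The two-group ordering (decayed parameters first, then everything else) is what the assertions above rely on. For context, a hedged sketch of eyeballing the same split by name outside the test suite; the toy model is hypothetical and the default TrainingArguments (output_dir "tmp_trainer") are assumed, as in the test.

```python
import torch
from transformers import Trainer

# Hypothetical toy model: one Linear plus one LayerNorm is enough to populate both groups.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
trainer = Trainer(model=model)  # default TrainingArguments, as in the test above
trainer.create_optimizer_and_scheduler(num_training_steps=10)

name_by_param = {id(p): n for n, p in model.named_parameters()}
for i, group in enumerate(trainer.optimizer.param_groups):
    print(i, group["weight_decay"], [name_by_param[id(p)] for p in group["params"]])
# Expected: group 0 -> ['0.weight'] with args.weight_decay applied;
#           group 1 -> ['0.bias', '1.weight', '1.bias'] with weight_decay 0.0.
```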
tests/test_trainer_utils.py

@@ -30,8 +30,23 @@ if is_torch_available():
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,
+        get_parameter_names,
     )

+    class TstLayer(torch.nn.Module):
+        def __init__(self, hidden_size):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln1 = torch.nn.LayerNorm(hidden_size)
+            self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
+            self.ln2 = torch.nn.LayerNorm(hidden_size)
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
+
+        def forward(self, x):
+            h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
+            h = torch.nn.functional.relu(self.linear2(x))
+            return self.ln2(x + h + self.bias)
+
     @require_torch
     class TrainerUtilsTest(unittest.TestCase):

@@ -117,3 +132,12 @@ class TrainerUtilsTest(unittest.TestCase):
         self.assertEqual(lengths[indices_process_0[0]], 50)
         # The indices should be a permutation of range(100)
         self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
+
+    def test_get_parameter_names(self):
+        model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
+        # fmt: off
+        self.assertEqual(
+            get_parameter_names(model, [torch.nn.LayerNorm]),
+            ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
+        )
+        # fmt: on