chenpangpang / transformers · Commits · 1ed2ebf6

[style] consistent nn. and nn.functional (#12124)

Unverified commit 1ed2ebf6, authored Jun 14, 2021 by Stas Bekman, committed by GitHub on Jun 14, 2021.

* consistent nn. and nn.functional
* fix glitch
* fix glitch #2
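The convention this commit enforces: import the alias once with "from torch import nn", then write nn.X and nn.functional.X everywhere instead of mixing those spellings with the longer torch.nn.X and torch.nn.functional.X. Below is a minimal, self-contained sketch of the style; the toy module is illustrative only, not code from the commit.

import torch
from torch import nn  # import the alias once; use nn.* and nn.functional.* below

class TinyClassifier(nn.Module):  # nn.Module instead of torch.nn.Module
    def __init__(self, hidden: int = 8, classes: int = 3):
        super().__init__()
        self.proj = nn.Linear(hidden, classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.functional.log_softmax instead of torch.nn.functional.log_softmax
        return nn.functional.log_softmax(self.proj(x), dim=-1)

model = TinyClassifier()
log_probs = model(torch.randn(2, 8))  # shape (2, 3)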
Parent: ff7c8168

Changes: 63 files in the full commit; this page shows 3 changed files, with 11 additions and 9 deletions (+11 / -9):

  src/transformers/optimization.py      +4 / -3
  src/transformers/trainer.py           +5 / -5
  src/transformers/trainer_pt_utils.py  +2 / -1
src/transformers/optimization.py

@@ -18,6 +18,7 @@ import math
 from typing import Callable, Iterable, Optional, Tuple, Union

 import torch
+from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR

@@ -272,7 +273,7 @@ class AdamW(Optimizer):
     <https://arxiv.org/abs/1711.05101>`__.

     Parameters:
-        params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
+        params (:obj:`Iterable[nn.parameter.Parameter]`):
             Iterable of parameters to optimize or dictionaries defining parameter groups.
         lr (:obj:`float`, `optional`, defaults to 1e-3):
             The learning rate to use.

@@ -288,7 +289,7 @@ class AdamW(Optimizer):
     def __init__(
         self,
-        params: Iterable[torch.nn.parameter.Parameter],
+        params: Iterable[nn.parameter.Parameter],
         lr: float = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-6,

@@ -379,7 +380,7 @@ class Adafactor(Optimizer):
     `relative_step=False`.

     Arguments:
-        params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
+        params (:obj:`Iterable[nn.parameter.Parameter]`):
             Iterable of parameters to optimize or dictionaries defining parameter groups.
         lr (:obj:`float`, `optional`):
             The external learning rate.
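For reference, the AdamW constructor whose type hint changed above can be exercised on its own. A minimal sketch, assuming a throwaway nn.Linear model; the model and tensors are illustrative, not part of the commit.

import torch
from torch import nn
from transformers.optimization import AdamW

model = nn.Linear(4, 2)
# params is an iterable of nn.parameter.Parameter (or dicts defining parameter groups)
optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-6)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()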
src/transformers/trainer.py

@@ -264,7 +264,7 @@ class Trainer:
     def __init__(
         self,
-        model: Union[PreTrainedModel, torch.nn.Module] = None,
+        model: Union[PreTrainedModel, nn.Module] = None,
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,

@@ -772,7 +772,7 @@ class Trainer:
         Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizer is None:
-            decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
+            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
             decay_parameters = [name for name in decay_parameters if "bias" not in name]
             optimizer_grouped_parameters = [
                 {

@@ -933,7 +933,7 @@ class Trainer:
         # Multi-gpu training (should be after apex fp16 initialization)
         if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
+            model = nn.DataParallel(model)

         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

@@ -970,7 +970,7 @@ class Trainer:
                 find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
             else:
                 find_unused_parameters = True
-            model = torch.nn.parallel.DistributedDataParallel(
+            model = nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,

@@ -1288,7 +1288,7 @@ class Trainer:
                         model.clip_grad_norm_(args.max_grad_norm)
                     else:
                         # Revert to normal clipping otherwise, handling Apex or full precision
-                        torch.nn.utils.clip_grad_norm_(
+                        nn.utils.clip_grad_norm_(
                             amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                             args.max_grad_norm,
                         )
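The renamed calls above are the same PyTorch APIs reached through the shorter alias. A minimal sketch of the gradient-clipping step outside the Trainer, with a toy model; illustrative only, not code from the commit.

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
model(torch.randn(2, 4)).sum().backward()

# identical to torch.nn.utils.clip_grad_norm_; returns the total gradient norm before clipping
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))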
src/transformers/trainer_pt_utils.py

@@ -28,6 +28,7 @@ from typing import Dict, Iterator, List, Optional, Union
 import numpy as np
 import torch
 from packaging import version
+from torch import nn
 from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler

@@ -441,7 +442,7 @@ class LabelSmoother:
     def __call__(self, model_output, labels):
         logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
-        log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
+        log_probs = -nn.functional.log_softmax(logits, dim=-1)
         if labels.dim() == log_probs.dim() - 1:
             labels = labels.unsqueeze(-1)
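The rewritten LabelSmoother line computes per-token negative log-probabilities. A minimal stand-alone sketch of that step with toy tensors; the final gather is only an assumed continuation for illustration, not part of the diff shown here.

import torch
from torch import nn

logits = torch.randn(2, 5, 10)          # (batch, seq_len, vocab)
labels = torch.randint(0, 10, (2, 5))   # gold token ids, (batch, seq_len)

log_probs = -nn.functional.log_softmax(logits, dim=-1)
if labels.dim() == log_probs.dim() - 1:  # add a trailing dim so gather can index the vocab axis
    labels = labels.unsqueeze(-1)

nll = log_probs.gather(dim=-1, index=labels)  # negative log-probability of each gold token
print(nll.shape)                              # torch.Size([2, 5, 1])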