chenpangpang/transformers - Commits

Commit 1ed2ebf6 (unverified)
Authored Jun 14, 2021 by Stas Bekman; committed by GitHub on Jun 14, 2021
[style] consistent nn. and nn.functional (#12124)
* consistent nn. and nn.functional
* fix glitch
* fix glitch #2
Parent: ff7c8168

Changes: 63 files in the full commit; this page shows 3 changed files with 11 additions and 9 deletions (+11 / -9).
src/transformers/optimization.py      +4 -3
src/transformers/trainer.py           +5 -5
src/transformers/trainer_pt_utils.py  +2 -1
src/transformers/optimization.py
@@ -18,6 +18,7 @@ import math
 from typing import Callable, Iterable, Optional, Tuple, Union
 
 import torch
+from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
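The change is safe because `from torch import nn` binds the very module object that `torch.nn` names; a quick sanity check, as a minimal sketch (not part of the diff):

    import torch
    from torch import nn

    # Both spellings resolve to the same module object, so every shortened
    # reference in this commit is identical to its fully qualified form.
    assert nn is torch.nn
    assert nn.LayerNorm is torch.nn.LayerNorm
    assert nn.functional.log_softmax is torch.nn.functional.log_softmax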
@@ -272,7 +273,7 @@ class AdamW(Optimizer):
     <https://arxiv.org/abs/1711.05101>`__.
 
     Parameters:
-        params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
+        params (:obj:`Iterable[nn.parameter.Parameter]`):
             Iterable of parameters to optimize or dictionaries defining parameter groups.
         lr (:obj:`float`, `optional`, defaults to 1e-3):
             The learning rate to use.
@@ -288,7 +289,7 @@ class AdamW(Optimizer):
 
     def __init__(
         self,
-        params: Iterable[torch.nn.parameter.Parameter],
+        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
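The shortened annotation does not change how the optimizer is called; a minimal usage sketch with a placeholder model (hyperparameter values are the documented defaults):

    from torch import nn
    from transformers import AdamW

    model = nn.Linear(10, 2)  # placeholder model
    optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-6)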
@@ -379,7 +380,7 @@ class Adafactor(Optimizer):
         `relative_step=False`.
 
     Arguments:
-        params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
+        params (:obj:`Iterable[nn.parameter.Parameter]`):
             Iterable of parameters to optimize or dictionaries defining parameter groups.
         lr (:obj:`float`, `optional`):
             The external learning rate.
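Adafactor takes the same `params` iterable; a hedged usage sketch in external-learning-rate mode (values illustrative):

    from torch import nn
    from transformers import Adafactor

    model = nn.Linear(10, 2)  # placeholder model
    # With an explicit lr, the internal schedule must be disabled
    # via relative_step=False.
    optimizer = Adafactor(
        model.parameters(), lr=1e-3, scale_parameter=False, relative_step=False, warmup_init=False
    )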
src/transformers/trainer.py
@@ -264,7 +264,7 @@ class Trainer:
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, torch.nn.Module] = None,
+        model: Union[PreTrainedModel, nn.Module] = None,
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
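The annotation (before and after the change) admits any `nn.Module`, not only a `PreTrainedModel`; a minimal construction sketch (model and output directory are placeholders):

    from torch import nn
    from transformers import Trainer, TrainingArguments

    model = nn.Linear(10, 2)  # any nn.Module is accepted
    trainer = Trainer(model=model, args=TrainingArguments(output_dir="out"))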
@@ -772,7 +772,7 @@ class Trainer:
         Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizer is None:
-            decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
+            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
             decay_parameters = [name for name in decay_parameters if "bias" not in name]
             optimizer_grouped_parameters = [
                 {
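The unchanged lines around this hunk split those names into decay / no-decay groups; a simplified sketch of that pattern on a toy model (the 0.01 decay value is illustrative; the real code uses `get_parameter_names` and `self.args.weight_decay`):

    from torch import nn

    model = nn.Sequential(nn.Linear(10, 10), nn.LayerNorm(10), nn.Linear(10, 2))

    # Parameters living inside LayerNorm modules are exempt from decay.
    no_decay = {
        f"{mod_name}.{param_name}"
        for mod_name, mod in model.named_modules()
        if isinstance(mod, nn.LayerNorm)
        for param_name, _ in mod.named_parameters()
    }
    # Mirror the Trainer logic: also exempt every bias.
    decay_parameters = [
        n for n, _ in model.named_parameters() if n not in no_decay and "bias" not in n
    ]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if n in decay_parameters],
         "weight_decay": 0.01},
        {"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
         "weight_decay": 0.0},
    ]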
@@ -933,7 +933,7 @@ class Trainer:
 
         # Multi-gpu training (should be after apex fp16 initialization)
         if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
+            model = nn.DataParallel(model)
 
         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
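In isolation, the single-process multi-GPU wrap is just the following (a sketch; it needs more than one visible CUDA device to do anything):

    import torch
    from torch import nn

    model = nn.Linear(10, 2)
    if torch.cuda.device_count() > 1:
        # Replicates the module per GPU and scatters each batch along dim 0.
        model = nn.DataParallel(model)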
@@ -970,7 +970,7 @@ class Trainer:
                 find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
             else:
                 find_unused_parameters = True
-            model = torch.nn.parallel.DistributedDataParallel(
+            model = nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
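The distributed branch is the standard DDP wrap; a sketch outside the Trainer, assuming `torch.distributed.init_process_group` has already run and `local_rank` came from the launcher (the value here is illustrative):

    from torch import nn

    local_rank = 0  # normally provided by the distributed launcher
    model = nn.Linear(10, 2).to(local_rank)
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,  # matches the fallback branch above
    )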
@@ -1288,7 +1288,7 @@ class Trainer:
model
.
clip_grad_norm_
(
args
.
max_grad_norm
)
else
:
# Revert to normal clipping otherwise, handling Apex or full precision
torch
.
nn
.
utils
.
clip_grad_norm_
(
nn
.
utils
.
clip_grad_norm_
(
amp
.
master_params
(
self
.
optimizer
)
if
self
.
use_apex
else
model
.
parameters
(),
args
.
max_grad_norm
,
)
...
...
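Outside the Apex branch this is the stock PyTorch utility; a minimal training-step sketch (model, data, and max norm are illustrative):

    import torch
    from torch import nn

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    # Rescales gradients in place so their global norm is at most max_norm.
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()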
src/transformers/trainer_pt_utils.py
@@ -28,6 +28,7 @@ from typing import Dict, Iterator, List, Optional, Union
 import numpy as np
 import torch
 from packaging import version
+from torch import nn
 from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
@@ -441,7 +442,7 @@ class LabelSmoother:
 
     def __call__(self, model_output, labels):
         logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
-        log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
+        log_probs = -nn.functional.log_softmax(logits, dim=-1)
         if labels.dim() == log_probs.dim() - 1:
             labels = labels.unsqueeze(-1)
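The negated log-softmax then feeds the usual label-smoothing mix; a standalone sketch of that computation (epsilon and shapes illustrative; the real class also masks padded labels):

    import torch
    from torch import nn

    epsilon = 0.1
    logits = torch.randn(4, 7)           # (batch, num_labels)
    labels = torch.randint(0, 7, (4,))   # (batch,)

    log_probs = -nn.functional.log_softmax(logits, dim=-1)
    nll_loss = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True) / log_probs.size(-1)
    loss = (1 - epsilon) * nll_loss.mean() + epsilon * smoothed_loss.mean()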