chenpangpang / transformers

Commit 88e84186 (Unverified)
Authored Jun 14, 2021 by Stas Bekman, committed by GitHub on Jun 14, 2021
Parent: 372ab9cd

[style] consistent nn. and nn.functional: part 4 `examples` (#12156)

* consistent nn. and nn.functional: p4 examples
* restore
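The change is the same in every touched file: the extra aliases `import torch.nn.functional as F` and `import torch.nn as nn` are dropped in favour of a single `from torch import nn`, and call sites are spelled `nn.functional.*` / `nn.*` instead of `F.*` / `torch.nn.*`. A minimal before/after sketch of the pattern (the tensors here are made up for illustration, not taken from the changed files):

    import torch
    from torch import nn  # the one import this commit standardizes on

    logits = torch.randn(4, 10)           # illustrative: batch of 4, 10 classes
    targets = torch.randint(0, 10, (4,))  # illustrative integer labels

    # old style (removed): F.log_softmax(logits, dim=-1) with "import torch.nn.functional as F"
    # new style (added):
    log_probs = nn.functional.log_softmax(logits, dim=-1)
    loss = nn.functional.nll_loss(log_probs, targets)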
Showing 6 changed files with 19 additions and 18 deletions (+19 −18):

  examples/research_projects/pplm/run_pplm_discrim_train.py                  +5 −5
  examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py  +2 −1
  examples/research_projects/seq2seq-distillation/distillation.py            +7 −8
  examples/research_projects/seq2seq-distillation/finetune.py                +3 −2
  examples/research_projects/wav2vec2/run_asr.py                             +1 −1
  examples/research_projects/wav2vec2/run_pretrain.py                        +1 −1
examples/research_projects/pplm/run_pplm_discrim_train.py

@@ -23,10 +23,10 @@ import time
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
 from nltk.tokenize.treebank import TreebankWordDetokenizer
+from torch import nn
 from torchtext import data as torchtext_data
 from torchtext import datasets
 from tqdm import tqdm, trange
@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100


-class Discriminator(torch.nn.Module):
+class Discriminator(nn.Module):
     """Transformer encoder followed by a Classification Head"""

     def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
         avg_hidden = self.avg_representation(x.to(self.device))
         logits = self.classifier_head(avg_hidden)
-        probs = F.log_softmax(logits, dim=-1)
+        probs = nn.functional.log_softmax(logits, dim=-1)
         return probs
@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
         optimizer.zero_grad()
         output_t = discriminator(input_t)
-        loss = F.nll_loss(output_t, target_t)
+        loss = nn.functional.nll_loss(output_t, target_t)
         loss.backward(retain_graph=True)
         optimizer.step()
@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
             input_t, target_t = input_t.to(device), target_t.to(device)
             output_t = discriminator(input_t)
             # sum up batch loss
-            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
+            test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
             # get the index of the max log-probability
             pred_t = output_t.argmax(dim=1, keepdim=True)
             correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
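The `nn.functional` calls touched in this file implement a standard classification objective: log-probabilities from the discriminator head go into a negative log-likelihood loss. A self-contained sketch of that pattern with the new spelling (the real `Discriminator` in the file is richer; this toy head and its tensors are only illustrative):

    import torch
    from torch import nn

    class ToyHead(nn.Module):
        """Illustrative stand-in for the classifier head in run_pplm_discrim_train.py."""
        def __init__(self, embed_size=16, class_size=3):
            super().__init__()
            self.mlp = nn.Linear(embed_size, class_size)

        def forward(self, hidden):
            logits = self.mlp(hidden)
            # same call as the updated forward(): log-probabilities over classes
            return nn.functional.log_softmax(logits, dim=-1)

    head = ToyHead()
    hidden = torch.randn(8, 16)           # fake averaged hidden states
    targets = torch.randint(0, 3, (8,))   # fake class labels
    log_probs = head(hidden)
    loss = nn.functional.nll_loss(log_probs, targets)  # as in train_epoch()
    loss.backward()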
examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py

@@ -8,6 +8,7 @@ from pathlib import Path
 import pytest
 import pytorch_lightning as pl
 import torch
+from torch import nn

 import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
@@ -183,7 +184,7 @@ class TestSummarizationDistiller(TestCasePlus):
         logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
-        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+        lprobs = nn.functional.log_softmax(logits, dim=-1)
         smoothed_loss, nll_loss = label_smoothed_nll_loss(lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id)
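The test feeds the `nn.functional.log_softmax` output into `label_smoothed_nll_loss`, whose implementation is not part of this diff. Purely for orientation, a common formulation of label-smoothed NLL over log-probabilities looks roughly like the sketch below; the function name and details here are hypothetical, and the repository's own helper may differ:

    import torch

    def label_smoothed_nll_loss_sketch(lprobs, target, epsilon, ignore_index):
        """Hypothetical helper; ignore_index is assumed to be a valid token id (e.g. pad_token_id)."""
        lprobs = lprobs.view(-1, lprobs.size(-1))          # (batch * seq, vocab)
        target = target.view(-1, 1)                        # (batch * seq, 1)
        nll_loss = -lprobs.gather(dim=-1, index=target)    # log-prob of the gold token
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)    # mass spread uniformly over the vocab
        pad_mask = target.eq(ignore_index)
        nll_loss = nll_loss.masked_fill(pad_mask, 0.0).sum()
        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0).sum()
        eps_i = epsilon / lprobs.size(-1)
        loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
        return loss, nll_loss

    # usage mirroring the call in the test above (all tensors made up)
    lprobs = torch.log_softmax(torch.randn(2, 5, 11), dim=-1)   # (batch, seq, vocab)
    lm_labels = torch.randint(1, 11, (2, 5))
    smoothed_loss, nll_loss = label_smoothed_nll_loss_sketch(lprobs, lm_labels, 0.1, ignore_index=0)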
examples/research_projects/seq2seq-distillation/distillation.py

@@ -10,7 +10,6 @@ from typing import List
 import pytorch_lightning as pl
 import torch
 from torch import nn
-from torch.nn import functional as F

 from finetune import SummarizationModule, TranslationModule
 from finetune import main as ft_main
@@ -123,8 +122,8 @@ class SummarizationDistiller(SummarizationModule):
         assert t_logits_slct.size() == s_logits_slct.size()
         loss_ce = (
             self.ce_loss_fct(
-                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
-                F.softmax(t_logits_slct / self.temperature, dim=-1),
+                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
+                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
             )
             * (self.temperature) ** 2
         )
@@ -160,10 +159,10 @@ class SummarizationDistiller(SummarizationModule):
         assert lm_logits.shape[-1] == self.model.config.vocab_size
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
             student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
         else:
-            lprobs = F.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             student_lm_loss, _ = label_smoothed_nll_loss(
                 lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
@@ -230,9 +229,9 @@ class SummarizationDistiller(SummarizationModule):
         teacher_states = torch.stack([hidden_states_T[j] for j in matches])
         assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
         if normalize_hidden:
-            student_states = F.layer_norm(student_states, student_states.shape[1:])
-            teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
-        mse = F.mse_loss(student_states, teacher_states, reduction="none")
+            student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
+            teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
+        mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
         masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
         return masked_mse
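The `loss_ce` block above is the usual knowledge-distillation recipe: student log-probabilities and teacher probabilities are both softened by a temperature, compared by `self.ce_loss_fct`, and the result is rescaled by the squared temperature. The definition of `ce_loss_fct` lies outside this diff; the sketch below assumes a KL-divergence loss with `reduction="batchmean"`, which is a customary choice, so treat it as an illustration rather than the file's exact configuration:

    import torch
    from torch import nn

    temperature = 2.0
    ce_loss_fct = nn.KLDivLoss(reduction="batchmean")  # assumed; not shown in this diff

    s_logits = torch.randn(4, 32)   # illustrative student logits
    t_logits = torch.randn(4, 32)   # illustrative teacher logits

    loss_ce = (
        ce_loss_fct(
            nn.functional.log_softmax(s_logits / temperature, dim=-1),  # student: log-probs
            nn.functional.softmax(t_logits / temperature, dim=-1),      # teacher: probs
        )
        * temperature**2  # compensate for the gradient scaling introduced by the temperature
    )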
examples/research_projects/seq2seq-distillation/finetune.py

@@ -13,6 +13,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from torch import nn
 from torch.utils.data import DataLoader

 from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
@@ -151,12 +152,12 @@ class SummarizationModule(BaseTransformer):
         lm_logits = outputs["logits"]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)

             assert lm_logits.shape[-1] == self.vocab_size
             loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
-            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
                 lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
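In the non-smoothed branch the edit is purely cosmetic (`torch.nn.CrossEntropyLoss` becomes `nn.CrossEntropyLoss`); pad positions are skipped by `ignore_index` either way. A small illustration of that masking with made-up tensors and an assumed `pad_token_id` value:

    import torch
    from torch import nn

    pad_token_id = 0  # illustrative value, not taken from the file
    ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    lm_logits = torch.randn(2, 6, 11)        # (batch, seq, vocab), made up
    tgt_ids = torch.randint(1, 11, (2, 6))
    tgt_ids[:, -2:] = pad_token_id           # pretend the last two positions are padding

    # same flattening as the file: positions equal to pad_token_id contribute no loss
    loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))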
examples/research_projects/wav2vec2/run_asr.py

@@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
 import datasets
 import numpy as np
 import torch
-import torch.nn as nn
 from packaging import version
+from torch import nn

 import librosa
 from lang_trans import arabic
examples/research_projects/wav2vec2/run_pretrain.py

@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union

 import torch
-import torch.nn as nn
 from datasets import DatasetDict, load_dataset
 from packaging import version
+from torch import nn

 import librosa
 from transformers import (
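The two wav2vec2 scripts only swap the import style: `import torch.nn as nn` and `from torch import nn` bind the very same module object, so runtime behaviour is unchanged. A quick check:

    import torch.nn as legacy_nn   # old spelling removed by this commit
    from torch import nn           # new spelling

    assert legacy_nn is nn         # both names point at the torch.nn module
    layer = nn.Linear(4, 2)        # behaves identically under either import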