Commit 88e84186 in chenpangpang/transformers (unverified)
Authored Jun 14, 2021 by Stas Bekman; committed via GitHub on Jun 14, 2021
[style] consistent nn. and nn.functional: part 4 `examples` (#12156)
* consistent nn. and nn.functional: p4 examples
* restore
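The convention this series standardizes on: import the nn namespace once with "from torch import nn", subclass nn.Module directly, and reach functional ops through nn.functional.* instead of mixing torch.nn.X, "import torch.nn as nn", and "import torch.nn.functional as F". A minimal sketch of the target style (toy code, not taken from this repository):

import torch
from torch import nn  # single import; no separate "import torch.nn.functional as F"


class TinyClassifier(nn.Module):
    """Toy module written in the style this commit enforces."""

    def __init__(self, hidden_size=16, num_classes=4):
        super().__init__()
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        logits = self.head(x)
        # functional ops go through nn.functional.*, not F.*
        return nn.functional.log_softmax(logits, dim=-1)


model = TinyClassifier()
log_probs = model(torch.randn(2, 16))
loss = nn.functional.nll_loss(log_probs, torch.tensor([0, 3]))
print(loss.item())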
Parent: 372ab9cd
Showing 6 changed files with 19 additions and 18 deletions
examples/research_projects/pplm/run_pplm_discrim_train.py (+5 / -5)
examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py (+2 / -1)
examples/research_projects/seq2seq-distillation/distillation.py (+7 / -8)
examples/research_projects/seq2seq-distillation/finetune.py (+3 / -2)
examples/research_projects/wav2vec2/run_asr.py (+1 / -1)
examples/research_projects/wav2vec2/run_pretrain.py (+1 / -1)
examples/research_projects/pplm/run_pplm_discrim_train.py

@@ -23,10 +23,10 @@ import time
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
 from nltk.tokenize.treebank import TreebankWordDetokenizer
+from torch import nn
 from torchtext import data as torchtext_data
 from torchtext import datasets
 from tqdm import tqdm, trange

@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100

-class Discriminator(torch.nn.Module):
+class Discriminator(nn.Module):
     """Transformer encoder followed by a Classification Head"""

     def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):

@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
         avg_hidden = self.avg_representation(x.to(self.device))
         logits = self.classifier_head(avg_hidden)
-        probs = F.log_softmax(logits, dim=-1)
+        probs = nn.functional.log_softmax(logits, dim=-1)
         return probs

@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
         optimizer.zero_grad()
         output_t = discriminator(input_t)
-        loss = F.nll_loss(output_t, target_t)
+        loss = nn.functional.nll_loss(output_t, target_t)
         loss.backward(retain_graph=True)
         optimizer.step()

@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
             input_t, target_t = input_t.to(device), target_t.to(device)
             output_t = discriminator(input_t)
             # sum up batch loss
-            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
+            test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
             # get the index of the max log-probability
             pred_t = output_t.argmax(dim=1, keepdim=True)
             correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
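For context, the pattern in this file is that the discriminator's forward returns log-probabilities (via nn.functional.log_softmax), so both the training and evaluation loops pair it with nn.functional.nll_loss. A hedged sketch of that pairing, with hypothetical names rather than the file's own helpers:

import torch
from torch import nn


def train_step(model, optimizer, input_t, target_t):
    """One optimization step; assumes model(input_t) returns log-probabilities."""
    optimizer.zero_grad()
    output_t = model(input_t)
    loss = nn.functional.nll_loss(output_t, target_t)
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def evaluate(model, data_loader, device="cpu"):
    """Average NLL and accuracy over a loader; mirrors the sum/argmax pattern above."""
    test_loss, correct, total = 0.0, 0, 0
    for input_t, target_t in data_loader:
        input_t, target_t = input_t.to(device), target_t.to(device)
        output_t = model(input_t)
        # sum up batch loss
        test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
        # get the index of the max log-probability
        pred_t = output_t.argmax(dim=1, keepdim=True)
        correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
        total += target_t.size(0)
    return test_loss / total, correct / total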
examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py

@@ -8,6 +8,7 @@ from pathlib import Path
 import pytest
 import pytorch_lightning as pl
 import torch
+from torch import nn
 import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf

@@ -183,7 +184,7 @@ class TestSummarizationDistiller(TestCasePlus):
         logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
-        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+        lprobs = nn.functional.log_softmax(logits, dim=-1)
         smoothed_loss, nll_loss = label_smoothed_nll_loss(
             lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
         )
examples/research_projects/seq2seq-distillation/distillation.py

@@ -10,7 +10,6 @@ from typing import List
 import pytorch_lightning as pl
 import torch
 from torch import nn
-from torch.nn import functional as F
 from finetune import SummarizationModule, TranslationModule
 from finetune import main as ft_main

@@ -123,8 +122,8 @@ class SummarizationDistiller(SummarizationModule):
         assert t_logits_slct.size() == s_logits_slct.size()
         loss_ce = (
             self.ce_loss_fct(
-                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
-                F.softmax(t_logits_slct / self.temperature, dim=-1),
+                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
+                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
             )
             * (self.temperature) ** 2
         )

@@ -160,10 +159,10 @@ class SummarizationDistiller(SummarizationModule):
         assert lm_logits.shape[-1] == self.model.config.vocab_size
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
             student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
         else:
-            lprobs = F.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             student_lm_loss, _ = label_smoothed_nll_loss(
                 lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
             )

@@ -230,9 +229,9 @@ class SummarizationDistiller(SummarizationModule):
     teacher_states = torch.stack([hidden_states_T[j] for j in matches])
     assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
     if normalize_hidden:
-        student_states = F.layer_norm(student_states, student_states.shape[1:])
-        teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
+        student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
+        teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
-    mse = F.mse_loss(student_states, teacher_states, reduction="none")
+    mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
     masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
     return masked_mse
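The ce_loss_fct call changed above is the usual soft-target distillation term: student log-softmax against teacher softmax at temperature T, scaled by T**2. A self-contained sketch of that computation in the new style; using nn.KLDivLoss here is an assumption made for illustration, since the diff does not show what ce_loss_fct actually is:

import torch
from torch import nn


def soft_target_loss(s_logits, t_logits, temperature=2.0):
    """Temperature-scaled distillation loss, written with nn.functional calls."""
    ce_loss_fct = nn.KLDivLoss(reduction="batchmean")  # assumption: a KL-style loss
    loss_ce = (
        ce_loss_fct(
            nn.functional.log_softmax(s_logits / temperature, dim=-1),
            nn.functional.softmax(t_logits / temperature, dim=-1),
        )
        * temperature**2
    )
    return loss_ce


# usage: student and teacher logits must have identical shapes
student, teacher = torch.randn(4, 10), torch.randn(4, 10)
print(soft_target_loss(student, teacher).item())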
examples/research_projects/seq2seq-distillation/finetune.py

@@ -13,6 +13,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from torch import nn
 from torch.utils.data import DataLoader
 from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback

@@ -151,12 +152,12 @@ class SummarizationModule(BaseTransformer):
         lm_logits = outputs["logits"]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
             assert lm_logits.shape[-1] == self.vocab_size
             loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
-            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
                 lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
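The branch changed above switches between plain cross-entropy and a label-smoothed loss computed on log-probabilities. A hedged sketch of that structure; smoothed_nll below is a simplified stand-in for the repository's label_smoothed_nll_loss helper, not its actual implementation:

import torch
from torch import nn


def smoothed_nll(lprobs, target, epsilon, ignore_index):
    """Simplified label-smoothed NLL: mix the true-class NLL with a uniform term."""
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    smooth = -lprobs.mean(dim=-1)
    pad_mask = target.eq(ignore_index)
    nll = nll.masked_fill(pad_mask, 0.0)
    smooth = smooth.masked_fill(pad_mask, 0.0)
    return ((1.0 - epsilon) * nll + epsilon * smooth).sum()


def seq2seq_loss(lm_logits, tgt_ids, label_smoothing, pad_token_id):
    if label_smoothing == 0:
        ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
        return ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
    lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
    return smoothed_nll(lprobs.view(-1, lprobs.size(-1)), tgt_ids.view(-1),
                        label_smoothing, pad_token_id)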
examples/research_projects/wav2vec2/run_asr.py

@@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
 import datasets
 import numpy as np
 import torch
-import torch.nn as nn
 from packaging import version
+from torch import nn
 import librosa
 from lang_trans import arabic
examples/research_projects/wav2vec2/run_pretrain.py

@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union
 import torch
-import torch.nn as nn
 from datasets import DatasetDict, load_dataset
 from packaging import version
+from torch import nn
 import librosa
 from transformers import (
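The two wav2vec2 changes are purely about the import form: "import torch.nn as nn" and "from torch import nn" bind the same module object, so the edit is stylistic and has no behavioral effect. A quick check:

import torch.nn as legacy_nn  # the form being removed
from torch import nn          # the form this series standardizes on

# both names point at the very same module object
assert nn is legacy_nn
print(nn.functional.relu is legacy_nn.functional.relu)  # True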