chenpangpang / transformers · Commits

Unverified commit dafa296c, authored Jul 28, 2020 by Sam Shleifer, committed via GitHub on Jul 28, 2020. Parent: dc4755c6

[s2s] Delete useless method, log tokens_per_batch (#6081)

The commit removes the Seq2SeqDataset.trim_seq2seq_batch helper, since model.generate can consume the padded batch directly, and logs the number of non-pad tokens per batch under the key "tpb" in training_step.

Showing 2 changed files with 14 additions and 15 deletions:

    examples/seq2seq/finetune.py   +14  -9
    examples/seq2seq/utils.py       +0  -6
examples/seq2seq/finetune.py (view file @ dafa296c)

@@ -160,9 +160,16 @@ class SummarizationModule(BaseTransformer):
         )
         return (loss,)
 
+    @property
+    def pad(self) -> int:
+        return self.tokenizer.pad_token_id
+
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
+
         logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        # tokens per batch
+        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
         return {"loss": loss_tensors[0], "log": logs}
 
     def validation_step(self, batch, batch_idx) -> Dict:
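
The new "tpb" entry simply counts the non-pad tokens on both the encoder and decoder side of the batch. A minimal standalone sketch of that computation, with a made-up toy batch and pad id (in the module these come from the tokenizer):

    import torch

    pad = 0  # hypothetical pad token id; in the module this is self.tokenizer.pad_token_id
    batch = {
        "input_ids": torch.tensor([[5, 8, 3, 0, 0], [7, 2, 0, 0, 0]]),
        "decoder_input_ids": torch.tensor([[4, 6, 0], [9, 0, 0]]),
    }

    # Same expression as the new logging line: non-pad source tokens plus non-pad target tokens.
    tpb = batch["input_ids"].ne(pad).sum() + batch["decoder_input_ids"].ne(pad).sum()
    print(tpb.item())  # 5 source tokens + 3 target tokens = 8 in this toy batch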
@@ -172,7 +179,7 @@ class SummarizationModule(BaseTransformer):
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
+        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
         rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
         rouges.update({k: v.item() for k, v in losses.items()})
         losses.update(rouges)
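
The only change in this hunk is the summ_len to gen_len key rename; the surrounding dict comprehension averages each metric over the per-batch validation outputs. A small sketch of that aggregation pattern, with invented toy values:

    import numpy as np

    # Toy per-batch dicts as validation_step might return them (numbers invented).
    outputs = [
        {"rouge2": 0.21, "gen_time": 0.5, "gen_len": 40},
        {"rouge2": 0.25, "gen_time": 0.7, "gen_len": 44},
    ]
    metric_names = ["rouge2"]

    rouges = {k: np.array([x[k] for x in outputs]).mean() for k in metric_names + ["gen_time", "gen_len"]}
    print(rouges)  # {'rouge2': 0.23, 'gen_time': 0.6, 'gen_len': 42.0}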
@@ -190,23 +197,21 @@ class SummarizationModule(BaseTransformer):
         return calculate_rouge(preds, target)
 
     def _generative_step(self, batch: dict) -> dict:
-        pad_token_id = self.tokenizer.pad_token_id
-        source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
         generated_ids = self.model.generate(
-            input_ids=source_ids,
-            attention_mask=source_mask,
+            batch["input_ids"],
+            attention_mask=batch["attention_mask"],
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
         )
-        gen_time = (time.time() - t0) / source_ids.shape[0]
-        preds = self.ids_to_clean_text(generated_ids)
-        target = self.ids_to_clean_text(y)
+        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
+        preds: List[str] = self.ids_to_clean_text(generated_ids)
+        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
         loss_tensors = self._step(batch)
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         rouge: Dict = self.calc_generative_metrics(preds, target)
         summ_len = np.mean(lmap(len, generated_ids))
-        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
+        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
         return base_metrics
 
     def test_step(self, batch, batch_idx):
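
After this change, _generative_step hands the padded, collated batch straight to generate and relies on the attention mask rather than trimming pad columns first. A rough standalone sketch of that call pattern; the checkpoint name is an arbitrary assumption, any seq2seq model would do:

    import time
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # assumed checkpoint
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    batch = tokenizer(["a short document", "a noticeably longer second document"],
                      padding=True, return_tensors="pt")

    t0 = time.time()
    # Padded ids go in as-is; the attention mask tells generate which positions to ignore,
    # which is what makes the old trim_seq2seq_batch preprocessing unnecessary.
    generated_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], use_cache=True)
    gen_time = (time.time() - t0) / batch["input_ids"].shape[0]  # seconds per example, as logged above
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True), gen_time)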
examples/seq2seq/utils.py (view file @ dafa296c)
@@ -128,12 +128,6 @@ class Seq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]
 
-    @staticmethod
-    def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
-        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
-        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
-        return source_ids, source_mask, y
-
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
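
For context on why the deleted staticmethod counted as "useless": it only chained two calls to the module's trim_batch helper, which drops the columns that are pad in every row. A self-contained sketch of that helper's behavior, re-implemented here so the snippet runs on its own:

    import torch

    def trim_batch(input_ids, pad_token_id, attention_mask=None):
        # Keep only columns that contain at least one non-pad token,
        # mirroring the trim_batch helper in examples/seq2seq/utils.py.
        keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]

    ids = torch.tensor([[5, 8, 0, 0], [7, 0, 0, 0]])
    print(trim_batch(ids, 0))  # tensor([[5, 8], [7, 0]])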