transformers · Commit fe326bd5 (Unverified)
Authored Sep 25, 2020 by Ola Piktus; committed by GitHub on Sep 25, 2020
Parent: ad39271a

Remove dependency on examples/seq2seq from rag (#7395)

Co-authored-by: Your Name <you@example.com>

Showing 3 changed files with 157 additions and 20 deletions (+157 -20):
examples/rag/callbacks.py   +87  -1
examples/rag/finetune.py    +8   -14
examples/rag/utils.py       +62  -5
examples/rag/callbacks.py  (view file @ fe326bd5)

import logging
import os
from pathlib import Path

from pytorch_lightning.callbacks import ModelCheckpoint

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only

from utils import save_json


def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


logger = logging.getLogger(__name__)

...
@@ -28,3 +41,76 @@ def get_checkpoint_callback(output_dir, metric):
        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )
    return checkpoint_callback


def get_early_stopping_callback(metric, patience):
    return EarlyStopping(
        monitor=f"val_{metric}",  # does this need avg?
        mode="min" if "loss" in metric else "max",
        patience=patience,
        verbose=True,
    )


class Seq2SeqLoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
        results_file.parent.mkdir(exist_ok=True)
        generations_file.parent.mkdir(exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.open("w+").write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        # Uncommenting this will save val generations
        # return self._write_logs(trainer, pl_module, "valid")
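The helpers above give the RAG fine-tuning script its checkpointing, early-stopping, and logging behaviour without reaching into examples/seq2seq. Below is a minimal sketch of how they might be attached to a pytorch_lightning Trainer; the module name `callbacks`, the output directory, the monitored metric "em", and the patience value are illustrative assumptions, and the exact Trainer keywords for checkpoint and early-stopping callbacks vary between Lightning versions.

# Hypothetical wiring of the callbacks above into a Trainer (not code from this commit).
import pytorch_lightning as pl

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback

output_dir = "rag_output"   # assumed output directory
metric = "em"               # assumed validation metric (exact match)

checkpoint_callback = get_checkpoint_callback(output_dir, metric)
early_stopping_callback = get_early_stopping_callback(metric, patience=3)

# Recent Lightning releases accept every callback through `callbacks=[...]`;
# older ones expect dedicated checkpoint_callback= / early_stop_callback= arguments.
trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[Seq2SeqLoggingCallback(), checkpoint_callback, early_stopping_callback],
)
# trainer.fit(model)  # `model` would be the GenerativeQAModule defined in finetune.py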
examples/rag/finetune.py  (view file @ fe326bd5)

...
@@ -34,22 +34,23 @@ from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd()))  # noqa: E402 # noqa: E402 # isort:skip

from examples.lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa: E402 # isort:skip
from examples.rag.callbacks import get_checkpoint_callback  # noqa: E402 # isort:skip
from examples.rag.callbacks import (  # noqa: E402 # isort:skip
    get_checkpoint_callback,
    get_early_stopping_callback,
    Seq2SeqLoggingCallback,
)
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402 # isort:skip
from examples.rag.utils import (  # noqa: E402 # isort:skip
    Seq2SeqDataset,
    calculate_exact_match,
    is_rag_model,
    set_extra_model_params,
)
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback  # noqa: E402 # isort:skip
from examples.seq2seq.utils import (  # noqa: E402 # isort:skip
    flatten_list,
    get_git_info,
    is_rag_model,
    lmap,
    pickle_save,
    save_git_info,
    save_json,
    set_extra_model_params,
    Seq2SeqDataset,
)

logging.basicConfig(level=logging.INFO)

...
@@ -303,11 +304,6 @@ class GenerativeQAModule(BaseTransformer):
    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

        dataloader = DataLoader(
            dataset,
...
@@ -315,7 +311,6 @@ class GenerativeQAModule(BaseTransformer):
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader
...
@@ -379,7 +374,6 @@ class GenerativeQAModule(BaseTransformer):
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
...
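One detail worth noting in the `get_dataloader` hunks above: PyTorch's `DataLoader` refuses a custom `sampler` together with `shuffle=True`, which is why `shuffle` is forced to `False` whenever the sortish sampler is selected. The following is a small self-contained sketch of that rule, with a toy dataset and a `SequentialSampler` standing in for `Seq2SeqDataset` and its sortish sampler.

# Toy illustration of the sampler-vs-shuffle rule used in get_dataloader;
# ToyDataset and SequentialSampler here are stand-ins, not code from the repo.
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler


class ToyDataset(Dataset):
    def __init__(self, n=8):
        self.data = list(range(n))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx])


dataset = ToyDataset()
sampler = SequentialSampler(dataset)  # stands in for dataset.make_sortish_sampler(batch_size)

# Passing shuffle=True together with sampler= raises a ValueError, so shuffle
# is disabled whenever an explicit sampler is supplied.
loader = DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=False)
for batch in loader:
    print(batch)  # tensor([0, 1, 2, 3]) then tensor([4, 5, 6, 7])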
examples/rag/utils.py  (view file @ fe326bd5)

import itertools
import json
import linecache
import os
import pickle
import re
import socket
import string
from collections import Counter
from logging import getLogger
from pathlib import Path
from typing import Dict, List
from typing import Callable, Dict, Iterable, List

import git
import torch
from torch.utils.data import Dataset

from examples.seq2seq.utils import SortishSampler, trim_batch
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer

...
@@ -27,6 +32,19 @@ def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=Tru
    )


def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


class Seq2SeqDataset(Dataset):
    def __init__(
        self,
...
@@ -114,13 +132,52 @@ class Seq2SeqDataset(Dataset):
        }
        return batch

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.src_lens, batch_size)


logger = getLogger(__name__)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
        "hostname": str(socket.gethostname()),
    }
    return repo_infos


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
...
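The `trim_batch` helper copied into `examples/rag/utils.py` drops every column that contains nothing but padding, so a batch is only as wide as its longest real sequence. A quick self-contained check of that behaviour follows; the pad id of 0 and the toy tensors are illustrative, and the import assumes the script is run from inside `examples/rag`.

# Sanity check for trim_batch on a hand-made padded batch (illustrative values).
import torch

from utils import trim_batch  # examples/rag/utils.py as changed in this commit

pad_token_id = 0  # assumed pad id for the illustration
input_ids = torch.tensor(
    [
        [5, 6, 7, 0, 0],
        [8, 9, 0, 0, 0],
    ]
)
attention_mask = input_ids.ne(pad_token_id).long()

trimmed_ids, trimmed_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask)
print(trimmed_ids.shape)  # torch.Size([2, 3]) - the two all-pad columns are removed
print(trimmed_ids)        # tensor([[5, 6, 7], [8, 9, 0]])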