chenpangpang / transformers · Commit 80f1a591

updated with latest PL and Ray (#15653)

Unverified commit, authored Feb 16, 2022 by Shamane Siri, committed by GitHub Feb 15, 2022.
Parent: 7bc4a01c
Showing 4 changed files with 21 additions and 9 deletions (+21 −9):
examples/research_projects/rag/callbacks_rag.py    +1 −1
examples/research_projects/rag/finetune_rag.py     +2 −2
examples/research_projects/rag/lightning_base.py   +15 −4
examples/research_projects/rag/requirements.txt    +3 −2
examples/research_projects/rag/callbacks_rag.py

@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=3,
-        period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
+        every_n_epochs=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
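This one-line change tracks the pytorch-lightning checkpoint API: the `period` argument was removed in newer releases in favour of `every_n_epochs`. A minimal standalone sketch of a checkpoint callback using the newer argument (the `dirpath` and `filename` values below are illustrative, not taken from the example):

from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch: keep the three best checkpoints by a validation metric,
# evaluating the monitor once per epoch. `every_n_epochs` replaces
# the removed `period` argument from older pytorch-lightning releases.
def make_checkpoint_callback(output_dir, metric):
    return ModelCheckpoint(
        dirpath=output_dir,        # illustrative: where checkpoints are written
        filename="step-{step}",    # illustrative filename pattern
        monitor=f"val_{metric}",   # metric name logged by the LightningModule
        mode="max",
        save_top_k=3,
        every_n_epochs=1,          # was `period=1` in the old API
    )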
examples/research_projects/rag/finetune_rag.py

@@ -254,7 +254,7 @@ class GenerativeQAModule(BaseTransformer):
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
-        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
         tgt_pad_token_id = (
             self.tokenizer.generator.pad_token_id
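The added `detach()` keeps the logging dict from holding a reference to the training graph, while the original `loss_tensors` stay attached for backpropagation. A small self-contained sketch of the pattern (the names and values below are illustrative, not from the example):

import torch

# Sketch: log detached copies of the per-component losses,
# but keep the attached tensors for the actual backward pass.
loss_names = ["loss", "doc_loss"]                        # illustrative names
loss_tensors = (torch.tensor(1.0, requires_grad=True),   # illustrative values
                torch.tensor(0.5, requires_grad=True))

logs = {name: loss.detach() for name, loss in zip(loss_names, loss_tensors)}
assert not logs["loss"].requires_grad   # safe to hand to a logger
assert loss_tensors[0].requires_grad    # still usable for .backward()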
@@ -517,7 +517,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
             raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
         # Connect to an existing Ray cluster.
         try:
-            ray.init(address=args.ray_address)
+            ray.init(address=args.ray_address, namespace="rag")
         except (ConnectionError, ValueError):
             logger.warning(
                 "Connection to Ray cluster failed. Make sure a Ray"
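Ray namespaces scope named actors, so passing `namespace="rag"` keeps this job's retriever actors from colliding with other work on the same cluster. A minimal sketch of connecting under a namespace (the `address="auto"` value is illustrative; the example itself takes the address from `args.ray_address`):

import ray

# Sketch: connect to an existing cluster under the "rag" namespace;
# fall back to starting a local Ray instance if no cluster is reachable.
try:
    ray.init(address="auto", namespace="rag")   # "auto" looks for a running cluster
except (ConnectionError, ValueError):
    ray.init(namespace="rag")                   # start a fresh local Ray instance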
examples/research_projects/rag/lightning_base.py

@@ -266,6 +266,15 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--adafactor", action="store_true")
 
 
+class InitCallback(pl.Callback):
+    # This method is better that using a custom DDP plugging with the latest pytorch-lightning (@shamanez)
+    def on_sanity_check_start(self, trainer, pl_module):
+        if (
+            trainer.is_global_zero and trainer.global_rank == 0
+        ):  # we initialize the retriever only on master worker with RAY. In new pytorch-lightning accelorators are removed.
+            pl_module.model.rag.retriever.init_retrieval()  # better to use hook functions.
+
+
 class LoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
         lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
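The new `InitCallback` replaces the earlier custom DDP plugin approach: with pytorch-lightning 1.5+ the old accelerator/plugin layout was reorganised, so one-off per-rank setup is easier to express as a `Callback` hook. A hedged sketch of the same idea with a generic setup action (`pl_module.setup_resources()` is a hypothetical stand-in for the retriever initialisation):

import pytorch_lightning as pl

# Sketch: run one-off setup on the rank-0 worker only, using a Callback
# hook instead of a custom DDP plugin.
class RankZeroInitCallback(pl.Callback):
    def on_sanity_check_start(self, trainer, pl_module):
        if trainer.is_global_zero and trainer.global_rank == 0:
            # Hypothetical stand-in for e.g. initialising a shared retriever index.
            pl_module.setup_resources()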
@@ -368,19 +377,21 @@ def generic_train(
     # TODO: remove with PyTorch 1.6 since pl uses native amp
     if args.fp16:
         train_params["precision"] = 16
-        train_params["amp_level"] = args.fp16_opt_level
+        # train_params["amp_level"] = args.fp16_opt_level
 
     if args.gpus > 1:
-        train_params["accelerator"] = "ddp"
+        train_params["accelerator"] = "auto"  # "ddp"
+        train_params["strategy"] = "ddp"
 
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
     train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None) #get unwanted logs
+    train_params["devices"] = "auto"
 
     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
-        plugins=[custom_ddp_plugin],
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback] + [InitCallback()],
+        # plugins=[custom_ddp_plugin],
         logger=logger,
         **train_params,
     )
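In pytorch-lightning 1.5+ the Trainer splits the old `accelerator="ddp"` setting into `accelerator` (hardware) and `strategy` (distribution scheme), with `devices` selecting how many units to use, which is what this hunk switches to. A minimal sketch of a Trainer built directly with those arguments (the `precision` and `max_epochs` values are illustrative):

import pytorch_lightning as pl

# Sketch: multi-GPU DDP training under the PL >= 1.5 Trainer arguments.
trainer = pl.Trainer(
    accelerator="auto",   # pick GPU/CPU automatically (was accelerator="ddp" pre-1.5)
    strategy="ddp",       # the distribution strategy is now a separate argument
    devices="auto",       # use all visible devices
    precision=16,         # mixed precision; replaces the old amp_level knob
    max_epochs=1,         # illustrative
)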
examples/research_projects/rag/requirements.txt

@@ -2,6 +2,7 @@ faiss-cpu >= 1.6.3
 datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
+ray >= 1.10.0
+pytorch-lightning >= 1.5.10
 transformers
-pytorch-lightning
-GitPython
+GitPython
\ No newline at end of file
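The new pins track the newer APIs used above: `ray >= 1.10.0` for the namespaced `ray.init` call and `pytorch-lightning >= 1.5.10` for the `strategy`/`devices` Trainer arguments and `every_n_epochs`. A small sketch for checking the installed versions before running the example (assumes Python 3.8+ for `importlib.metadata`):

from importlib.metadata import version

# Sketch: print the installed versions next to the pinned minimums.
for pkg, minimum in [("ray", "1.10.0"), ("pytorch-lightning", "1.5.10")]:
    print(f"{pkg}: installed {version(pkg)}, requires >= {minimum}")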