chenpangpang / transformers · Commits
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "29a1c1b472674030d61a6753cf1e3772f5d7131f"
Commit 80f1a591 (unverified) · authored Feb 16, 2022 by Shamane Siri, committed by GitHub on Feb 15, 2022
Parent: 7bc4a01c

updated with latest PL and Ray (#15653)
Showing 4 changed files with 21 additions and 9 deletions (+21 -9):
examples/research_projects/rag/callbacks_rag.py    +1  -1
examples/research_projects/rag/finetune_rag.py     +2  -2
examples/research_projects/rag/lightning_base.py   +15 -4
examples/research_projects/rag/requirements.txt    +3  -2
examples/research_projects/rag/callbacks_rag.py

@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=3,
-        period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
+        every_n_epochs=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
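The rename above tracks the pytorch-lightning API: the `period` argument of `ModelCheckpoint` was deprecated and later removed, with `every_n_epochs` taking its place. A minimal sketch of the equivalent callback construction, assuming pytorch-lightning >= 1.5; the directory and metric defaults are placeholders, not taken from the repo:

from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch only: dirpath and metric name are illustrative.
def build_checkpoint_callback(output_dir="checkpoints", metric="em"):
    return ModelCheckpoint(
        dirpath=output_dir,
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=3,
        every_n_epochs=1,  # replaces the removed `period` argument
    )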
examples/research_projects/rag/finetune_rag.py

@@ -254,7 +254,7 @@ class GenerativeQAModule(BaseTransformer):
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
-        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
         tgt_pad_token_id = (
             self.tokenizer.generator.pad_token_id
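Detaching the per-step losses before storing them in the logs dict keeps the logged values from holding on to the autograd graph between steps. A tiny self-contained illustration (the parameter and loss here are made up, not repo code):

import torch

w = torch.nn.Parameter(torch.ones(3))
loss = (w * 2).sum()

logs_raw = {"loss": loss}                # still attached to the graph
logs_detached = {"loss": loss.detach()}  # plain tensor, safe to keep around

print(logs_raw["loss"].requires_grad)       # True
print(logs_detached["loss"].requires_grad)  # False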
Still in examples/research_projects/rag/finetune_rag.py:

@@ -517,7 +517,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
             raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
         # Connect to an existing Ray cluster.
         try:
-            ray.init(address=args.ray_address)
+            ray.init(address=args.ray_address, namespace="rag")
         except (ConnectionError, ValueError):
             logger.warning(
                 "Connection to Ray cluster failed. Make sure a Ray"
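Passing namespace to ray.init scopes named actors to that namespace, which is presumably why the retriever connection now uses "rag". A hedged, self-contained sketch; the actor class and name are illustrative, and finetune_rag.py additionally passes address=args.ray_address:

import ray

ray.init(namespace="rag")  # the repo call also passes address=args.ray_address

@ray.remote
class RetrieverStub:
    def ping(self):
        return "pong"

# Named actors are only visible inside the namespace given to ray.init.
actor = RetrieverStub.options(name="retriever_worker_0").remote()
print(ray.get(actor.ping.remote()))  # -> "pong"
ray.shutdown()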
examples/research_projects/rag/lightning_base.py

@@ -266,6 +266,15 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--adafactor", action="store_true")
 
 
+class InitCallback(pl.Callback):
+    # This method is better than using a custom DDP plugin with the latest pytorch-lightning (@shamanez)
+    def on_sanity_check_start(self, trainer, pl_module):
+        if (
+            trainer.is_global_zero and trainer.global_rank == 0
+        ):  # we initialize the retriever only on the master worker with Ray; in new pytorch-lightning, accelerators are removed.
+            pl_module.model.rag.retriever.init_retrieval()  # better to use hook functions.
+
+
 class LoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
         lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
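The new InitCallback relies on a standard pytorch-lightning hook instead of a custom DDP plugin: on_sanity_check_start fires once before the validation sanity check, and the rank guard restricts the retriever initialization to the main process. A minimal sketch of the same pattern with a made-up callback:

import pytorch_lightning as pl

class OneTimeSetupCallback(pl.Callback):
    def on_sanity_check_start(self, trainer, pl_module):
        if trainer.is_global_zero:  # only the rank-0 process does the setup
            print("initializing shared resources once, on the main worker")

# Usage (sketch): pl.Trainer(callbacks=[OneTimeSetupCallback()], ...)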
Still in examples/research_projects/rag/lightning_base.py:

@@ -368,19 +377,21 @@ def generic_train(
     # TODO: remove with PyTorch 1.6 since pl uses native amp
     if args.fp16:
         train_params["precision"] = 16
-        train_params["amp_level"] = args.fp16_opt_level
+        # train_params["amp_level"] = args.fp16_opt_level
 
     if args.gpus > 1:
-        train_params["accelerator"] = "ddp"
+        train_params["accelerator"] = "auto"  # "ddp"
+        train_params["strategy"] = "ddp"
 
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
     train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None) #get unwanted logs
+    train_params["devices"] = "auto"
 
     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
-        plugins=[custom_ddp_plugin],
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback] + [InitCallback()],
+        # plugins=[custom_ddp_plugin],
         logger=logger,
         **train_params,
     )
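In pytorch-lightning >= 1.5 the DDP choice moves from accelerator to strategy, and devices selects the hardware; native AMP makes amp_level unnecessary. A hedged sketch of a Trainer built directly with the new arguments (model and datamodule omitted, not repo objects):

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="auto",  # hardware type picked automatically (was accelerator="ddp")
    devices="auto",      # use all visible devices
    strategy="ddp",      # distributed data parallel is now configured here
    precision=16,        # native AMP; the old amp_level argument is dropped
    max_epochs=1,
)
# trainer.fit(model, datamodule=datamodule)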
examples/research_projects/rag/requirements.txt

@@ -2,6 +2,7 @@ faiss-cpu >= 1.6.3
 datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
+ray >= 1.10.0
+pytorch-lightning >= 1.5.10
 transformers
-pytorch-lightning
 GitPython
\ No newline at end of file
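A quick way to confirm an environment matches the new pins, assuming both packages are installed:

import pytorch_lightning
import ray

print(pytorch_lightning.__version__)  # expected >= 1.5.10
print(ray.__version__)                # expected >= 1.10.0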