"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "29a1c1b472674030d61a6753cf1e3772f5d7131f"
Unverified commit 80f1a591, authored by Shamane Siri, committed by GitHub

updated with latest PL and Ray (#15653)

parent 7bc4a01c
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=3,
-        period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
+        every_n_epochs=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
...
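The `period` argument of `ModelCheckpoint` was removed in recent pytorch-lightning releases in favour of `every_n_epochs`. A minimal sketch of the updated callback construction, assuming pytorch-lightning >= 1.5; the `dirpath` and `filename` values are illustrative and not taken from the diff:

from pytorch_lightning.callbacks import ModelCheckpoint

def get_checkpoint_callback(output_dir, metric):
    # Keep the three best checkpoints ranked by the monitored validation metric.
    return ModelCheckpoint(
        dirpath=output_dir,         # illustrative: where checkpoints are written
        filename="{epoch}-{step}",  # illustrative file name pattern
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=3,
        every_n_epochs=1,           # replaces the removed `period` argument
    )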
@@ -254,7 +254,7 @@ class GenerativeQAModule(BaseTransformer):
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
-        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
         tgt_pad_token_id = (
             self.tokenizer.generator.pad_token_id
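Detaching the loss tensors before they are stashed for logging stops the logged copies from keeping the autograd graph (and its saved activations) alive across training steps. A small self-contained sketch of the semantics, using a stand-in tensor rather than the module above:

import torch

loss = torch.rand(1, requires_grad=True) * 2   # stand-in for a training loss
logs = {"loss": loss.detach()}                 # the logged copy carries no graph history
assert logs["loss"].requires_grad is False
assert loss.requires_grad is True              # the original loss still supports backprop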
@@ -517,7 +517,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
             raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
         # Connect to an existing Ray cluster.
         try:
-            ray.init(address=args.ray_address)
+            ray.init(address=args.ray_address, namespace="rag")
         except (ConnectionError, ValueError):
             logger.warning(
                 "Connection to Ray cluster failed. Make sure a Ray"
...
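Passing a fixed `namespace` to `ray.init` matters once named actors are involved: actors created in the "rag" namespace (such as distributed retrieval workers) can later be looked up by name from any process that joined the same namespace. A sketch assuming a local Ray >= 1.10 install; the actor class and name are illustrative only:

import ray

ray.init(namespace="rag")  # on a cluster: ray.init(address=..., namespace="rag")

@ray.remote
class RetrievalWorker:  # illustrative stand-in for a retrieval actor
    def ping(self):
        return "ok"

worker = RetrievalWorker.options(name="retrieval_worker_0").remote()
# Any process in the "rag" namespace can re-acquire the actor by its name.
same_worker = ray.get_actor("retrieval_worker_0")
print(ray.get(same_worker.ping.remote()))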
@@ -266,6 +266,15 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--adafactor", action="store_true")
+
+class InitCallback(pl.Callback):
+    # Using a hook here is better than a custom DDP plugin with the latest pytorch-lightning (@shamanez)
+    def on_sanity_check_start(self, trainer, pl_module):
+        if (
+            trainer.is_global_zero and trainer.global_rank == 0
+        ):  # initialize the retriever only on the master worker with Ray; custom accelerators are removed in new pytorch-lightning.
+            pl_module.model.rag.retriever.init_retrieval()  # better to use hook functions.
+
 class LoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
         lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
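The new `InitCallback` leans on a standard Lightning hook instead of a custom DDP plugin: `on_sanity_check_start` fires after the distributed worker processes exist but before the sanity-check validation loop, so rank-zero-only setup such as retriever initialization can run there. A hedged, self-contained sketch of the same hook with a dummy side effect (the callback name and print are illustrative):

import pytorch_lightning as pl

class OneTimeSetupCallback(pl.Callback):
    # Hypothetical callback showing the hook used by InitCallback above:
    # it runs once, after DDP workers are spawned, before sanity-check validation.
    def on_sanity_check_start(self, trainer, pl_module):
        if trainer.is_global_zero:  # restrict the side effect to the rank-zero worker
            print(f"global rank {trainer.global_rank}: one-time setup would run here")

# Usage (illustrative): pl.Trainer(callbacks=[OneTimeSetupCallback()], ...)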
@@ -368,19 +377,21 @@ def generic_train(
     # TODO: remove with PyTorch 1.6 since pl uses native amp
     if args.fp16:
         train_params["precision"] = 16
-        train_params["amp_level"] = args.fp16_opt_level
+        # train_params["amp_level"] = args.fp16_opt_level
     if args.gpus > 1:
-        train_params["accelerator"] = "ddp"
+        train_params["accelerator"] = "auto"  # "ddp"
+        train_params["strategy"] = "ddp"
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
     train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)  # get unwanted logs
+    train_params["devices"] = "auto"
     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback] + [InitCallback()],
-        plugins=[custom_ddp_plugin],
+        # plugins=[custom_ddp_plugin],
         logger=logger,
         **train_params,
     )
...
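In pytorch-lightning >= 1.5 the single `accelerator="ddp"` flag is split up: `accelerator` and `devices` select the hardware while `strategy` selects the distribution strategy, and `amp_level` is only accepted with apex AMP, so it is dropped once native 16-bit precision is used. A minimal sketch of the resulting Trainer configuration; all values are illustrative rather than taken from the diff:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="auto",  # pick CPU/GPU automatically
    devices="auto",      # use whatever devices are visible
    strategy="ddp",      # replaces the old accelerator="ddp"
    precision=16,        # native AMP; no amp_level needed (mirrors the fp16 branch above)
    max_epochs=3,
)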
@@ -2,6 +2,7 @@ faiss-cpu >= 1.6.3
 datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
+ray >= 1.10.0
+pytorch-lightning >= 1.5.10
 transformers
-pytorch-lightning
 GitPython
\ No newline at end of file