Unverified commit 5257818e, authored by Shamane Siri and committed by GitHub
Browse files

minor fixes in original RAG training (#12395)

parent e3f39a29
...@@ -36,7 +36,7 @@ def get_checkpoint_callback(output_dir, metric): ...@@ -36,7 +36,7 @@ def get_checkpoint_callback(output_dir, metric):
dirpath=output_dir, dirpath=output_dir,
filename=exp, filename=exp,
monitor=f"val_{metric}", monitor=f"val_{metric}",
mode="min", mode="max",
save_top_k=3, save_top_k=3,
period=1, # maybe save a checkpoint every time val is run, not just end of epoch. period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
) )
......
...@@ -532,8 +532,8 @@ def main(args=None, model=None) -> GenerativeQAModule: ...@@ -532,8 +532,8 @@ def main(args=None, model=None) -> GenerativeQAModule:
raise raise
# Create Ray actors only for rank 0. # Create Ray actors only for rank 0.
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and ( if ("LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0) and (
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0 "NODE_RANK" not in os.environ or int(os.environ["NODE_RANK"]) == 0
): ):
remote_cls = ray.remote(RayRetriever) remote_cls = ray.remote(RayRetriever)
named_actors = [ named_actors = [
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment