Unverified Commit ba47efbf authored by Phuc Van Phan's avatar Phuc Van Phan Committed by GitHub
Browse files

docs: change assert to raise and some small docs (#26232)

* docs: change assert to raise and some small docs

* docs: add rule and some document

* fix: fix bug

* fix: fix bug

* chorse: revert logging

* chorse: revert
parent 375b4e09
...@@ -246,13 +246,16 @@ def parse_args(): ...@@ -246,13 +246,16 @@ def parse_args():
else: else:
if args.train_file is not None: if args.train_file is not None:
extension = args.train_file.split(".")[-1] extension = args.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." if extension not in ["csv", "json", "txt"]:
raise ValueError("`train_file` should be a csv, json or txt file.")
if args.validation_file is not None: if args.validation_file is not None:
extension = args.validation_file.split(".")[-1] extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." if extension not in ["csv", "json", "txt"]:
raise ValueError("`validation_file` should be a csv, json or txt file.")
if args.push_to_hub: if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." if args.output_dir is None:
raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.")
return args return args
......
...@@ -261,7 +261,8 @@ def parse_args(): ...@@ -261,7 +261,8 @@ def parse_args():
raise ValueError("`validation_file` should be a csv, json or txt file.") raise ValueError("`validation_file` should be a csv, json or txt file.")
if args.push_to_hub: if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." if args.output_dir is None:
raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.")
return args return args
...@@ -694,7 +695,7 @@ def main(): ...@@ -694,7 +695,7 @@ def main():
except OverflowError: except OverflowError:
perplexity = float("inf") perplexity = float("inf")
logger.info(f"epoch {epoch}: perplexity: {perplexity}") logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
if args.with_tracking: if args.with_tracking:
accelerator.log( accelerator.log(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment