Unverified Commit a0c62d24 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix training from scratch in new scripts (#8623)

parent 1e62e999
...@@ -313,9 +313,12 @@ def main(): ...@@ -313,9 +313,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
trainer.train( model_path = (
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
) )
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation # Evaluation
......
...@@ -354,9 +354,12 @@ def main(): ...@@ -354,9 +354,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
trainer.train( model_path = (
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
) )
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation # Evaluation
......
...@@ -302,9 +302,12 @@ def main(): ...@@ -302,9 +302,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
trainer.train( model_path = (
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
) )
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation # Evaluation
......
...@@ -344,9 +344,12 @@ def main(): ...@@ -344,9 +344,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
trainer.train( model_path = (
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
) )
trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation # Evaluation
......
...@@ -307,9 +307,18 @@ def main(): ...@@ -307,9 +307,18 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
{%- if cookiecutter.can_train_from_scratch == "False" %}
trainer.train( trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
) )
{%- elif cookiecutter.can_train_from_scratch == "True" %}
model_path = (
model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
trainer.train(model_path=model_path)
{% endif %}
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation # Evaluation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment