"web/extensions/vscode:/vscode.git/clone" did not exist on "43f2505389a8cc37c3a76b676f2b31de65056d34"
Commit 506e5bb0 authored by tholor's avatar tholor
Browse files

add do_lower_case arg and adjust model saving for lm finetuning.

parent e485829a
...@@ -461,6 +461,9 @@ def main(): ...@@ -461,6 +461,9 @@ def main():
parser.add_argument("--on_memory", parser.add_argument("--on_memory",
action='store_true', action='store_true',
help="Whether to load train samples into memory or use disk") help="Whether to load train samples into memory or use disk")
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
default=-1, default=-1,
...@@ -612,12 +615,12 @@ def main(): ...@@ -612,12 +615,12 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model
logger.info("** ** * Saving fine - tuned model ** ** * ") logger.info("** ** * Saving fine - tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if n_gpu > 1: if args.do_train:
torch.save(model.module.bert.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
else:
torch.save(model.bert.state_dict(), output_model_file)
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(tokens_a, tokens_b, max_length):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment