"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "c8f7bb990c072c77d8cfc8a6c883ba0c352a5671"
Commit 5c08c8c2 authored by Oliver Guhr's avatar Oliver Guhr
Browse files

adds the tokenizer + model config to the output

parent 784c0ed8
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
import os
import torch import torch
import logging import logging
import json import json
...@@ -12,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler ...@@ -12,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertForPreTraining from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
...@@ -325,8 +327,13 @@ def main(): ...@@ -325,8 +327,13 @@ def main():
# Save a trained model # Save a trained model
logging.info("** ** * Saving fine-tuned model ** ** * ") logging.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = args.output_dir / "pytorch_model.bin"
torch.save(model_to_save.state_dict(), str(output_model_file)) output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -29,6 +29,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler ...@@ -29,6 +29,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertForPreTraining from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
...@@ -614,9 +615,12 @@ def main(): ...@@ -614,9 +615,12 @@ def main():
# Save a trained model # Save a trained model
logger.info("** ** * Saving fine - tuned model ** ** * ") logger.info("** ** * Saving fine - tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
if args.do_train: if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(tokens_a, tokens_b, max_length):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment