"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "2d105066df97a283fa155e39e0cc34ebbe58f55f"
Commit 47a06d88 authored by Rémi Louf

use two different tokenizers for story and summary

parent bfb9b540
@@ -26,7 +26,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
-from transformers import BertTokenizer
+from transformers import AutoTokenizer, Model2Model
 logger = logging.getLogger(__name__)
@@ -57,7 +57,7 @@ class TextDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """
-    def __init_(self, tokenizer, data_dir="", block_size=512):
+    def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
         assert os.path.isdir(data_dir)
         # Load features that have already been computed if present
@@ -90,15 +90,13 @@ class TextDataset(Dataset):
             except IndexError: # skip ill-formed stories
                 continue
-            summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-            summary_seq = _fit_to_block_size(summary, block_size)
-            story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
+            story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
             story_seq = _fit_to_block_size(story, block_size)
-            self.examples.append(
-                tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
-            )
+            summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
+            summary_seq = _fit_to_block_size(summary, block_size)
+            self.examples.append((story_seq, summary_seq))
         logger.info("Saving features into cache file %s", cached_features_file)
         with open(cached_features_file, "wb") as sink:
@@ -169,8 +167,8 @@ def _fit_to_block_size(sequence, block_size):
     return sequence.extend([-1] * [block_size - len(sequence)])
-def load_and_cache_examples(args, tokenizer):
-    dataset = TextDataset(tokenizer, file_path=args.data_dir)
+def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
+    dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
     return dataset
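For context, here is a minimal stand-alone sketch of what the updated preprocessing does for one example: the story is encoded with the source tokenizer, the summary with the target tokenizer, and each sequence is fitted to the block size before the pair is stored. The checkpoint names and the simplified fit_to_block_size helper are stand-ins for illustration, not the file's actual code.

# Minimal sketch, assuming bert-base-cased checkpoints for both sides; the helper
# below is a simplified stand-in for the file's _fit_to_block_size.
from transformers import AutoTokenizer

def fit_to_block_size(sequence, block_size):
    # Truncate long sequences; pad short ones with -1, as the dataset does.
    if len(sequence) > block_size:
        return sequence[:block_size]
    return sequence + [-1] * (block_size - len(sequence))

tokenizer_src = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_tgt = AutoTokenizer.from_pretrained("bert-base-cased")

story = "The quick brown fox jumps over the lazy dog."
summary = "A fox jumps over a dog."

# Each side of the pair is tokenized with its own tokenizer.
story_seq = fit_to_block_size(
    tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story)), 512
)
summary_seq = fit_to_block_size(
    tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary)), 512
)

example = (story_seq, summary_seq)  # what TextDataset now appends to self.examples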
@@ -205,14 +203,35 @@ def main():
     # Optional parameters
     parser.add_argument(
-        "--model_name_or_path",
+        "--decoder_name_or_path",
         default="bert-base-cased",
         type=str,
-        help="The model checkpoint for weights initialization.",
+        help="The model checkpoint to initialize the decoder's weights with.",
+    )
+    parser.add_argument(
+        "--decoder_type",
+        default="bert",
+        type=str,
+        help="The decoder architecture to be fine-tuned.",
+    )
+    parser.add_argument(
+        "--encoder_name_or_path",
+        default="bert-base-cased",
+        type=str,
+        help="The model checkpoint to initialize the encoder's weights with.",
+    )
+    parser.add_argument(
+        "--encoder_type",
+        default="bert",
+        type=str,
+        help="The encoder architecture to be fine-tuned.",
     )
     parser.add_argument("--seed", default=42, type=int)
     args = parser.parse_args()
+    if args.encoder_type != 'bert' or args.decoder_type != 'bert':
+        raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
     # Set up training device
     # device = torch.device("cpu")
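As a quick illustration of the new command-line surface, the following stand-alone snippet rebuilds only the arguments visible in this hunk (not the script's full parser) and shows how the BERT-only guard behaves.

# Sketch of the new encoder/decoder arguments and the BERT-only check; only the
# options shown in this hunk are included here.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--decoder_name_or_path", default="bert-base-cased", type=str)
parser.add_argument("--decoder_type", default="bert", type=str)
parser.add_argument("--encoder_name_or_path", default="bert-base-cased", type=str)
parser.add_argument("--encoder_type", default="bert", type=str)
parser.add_argument("--seed", default=42, type=int)

# The defaults pass the check; requesting a non-BERT encoder or decoder would raise.
args = parser.parse_args([])
if args.encoder_type != "bert" or args.decoder_type != "bert":
    raise ValueError("Only the BERT architecture is currently supported for seq2seq.")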
@@ -220,16 +239,15 @@ def main():
     set_seed(args)
     # Load pretrained model and tokenizer
-    tokenizer_class = BertTokenizer
-    # config = config_class.from_pretrained(args.model_name_or_path)
-    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
-    # model = model_class.from_pretrained(args.model_name_or_path, config=config)
+    encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
+    decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
+    model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
     # model.to(device)
     logger.info("Training/evaluation parameters %s", args)
     # Training
-    _ = load_and_cache_examples(args, tokenizer)
+    source, target = load_and_cache_examples(args, tokenizer)
     # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
     # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
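Putting the pieces together, here is a sketch of the new setup path in main(): one tokenizer per side, plus an encoder-decoder model built from two pretrained checkpoints. Model2Model was the experimental encoder-decoder wrapper in transformers at the time of this commit (later superseded by EncoderDecoderModel), so the import and the two-checkpoint from_pretrained call simply mirror the diff and may not match later releases; the checkpoint names are just the argument defaults.

# Sketch mirroring the new main(): separate tokenizers for encoder and decoder,
# and an encoder-decoder model built from two pretrained checkpoints.
# NOTE: Model2Model and the two-argument from_pretrained call are taken from the
# diff above; they may not exist in current transformers releases.
from transformers import AutoTokenizer, Model2Model

encoder_name_or_path = "bert-base-cased"  # default of --encoder_name_or_path
decoder_name_or_path = "bert-base-cased"  # default of --decoder_name_or_path

encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_name_or_path)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_name_or_path)
model = Model2Model.from_pretrained(encoder_name_or_path, decoder_name_or_path)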