"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "cb4b94ef5a7873a937e58e0917cff68516da2332"
Unverified commit 9dab39fe, authored by Sam Shleifer and committed by GitHub

seq2seq/run_eval.py can take decoder_start_token_id (#5949)

parent 5b193b39
@@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule):
         self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+            self.model.config.decoder_start_token_id = self.decoder_start_token_id
         if isinstance(self.tokenizer, MBartTokenizer):
             self.dataset_class = MBartDataset
...
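For context on the hunk above: MBart models begin decoding with the target-language code token, so when the config leaves decoder_start_token_id unset, it is looked up from the tokenizer and now also written back into the config. A minimal illustrative sketch; the checkpoint name and language code are examples, not part of this commit:

from transformers import MBartTokenizer

# Any MBart checkpoint with language codes behaves this way; en-ro is just an example.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
decoder_start_token_id = tokenizer.lang_code_to_id["ro_RO"]  # id of the ro_RO code token
# TranslationModule now stores this on model.config as well, so generate()
# picks it up even when the id is not passed per call.
print(decoder_start_token_id)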
@@ -30,6 +30,7 @@ def generate_summaries_or_translations(
     device: str = DEFAULT_DEVICE,
     fp16=False,
     task="summarization",
+    decoder_start_token_id=None,
     **gen_kwargs,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")
@@ -37,6 +38,8 @@ def generate_summaries_or_translations(
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
+    if decoder_start_token_id is None:
+        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
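A note on the fallback just added: an explicit decoder_start_token_id argument wins; otherwise a value carried inside **gen_kwargs is popped out so that generate() never receives the keyword twice. A standalone sketch with a hypothetical gen_kwargs dict (250020 is only an example id):

# Hypothetical kwargs as they might arrive from a caller.
gen_kwargs = {"num_beams": 4, "decoder_start_token_id": 250020}
decoder_start_token_id = None
if decoder_start_token_id is None:
    decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
print(decoder_start_token_id, gen_kwargs)  # 250020 {'num_beams': 4}
# The id now travels as the named argument, and gen_kwargs is safe to splat
# into model.generate() without a duplicate-keyword TypeError.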
@@ -48,7 +51,12 @@ def generate_summaries_or_translations(
             batch = [model.config.prefix + text for text in batch]
         batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
         input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
-        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
+        summaries = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_start_token_id=decoder_start_token_id,
+            **gen_kwargs,
+        )
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         for hypothesis in dec:
             fout.write(hypothesis + "\n")
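The forwarding above relies on generate() accepting decoder_start_token_id as a per-call override of config.decoder_start_token_id. A hedged end-to-end sketch; sshleifer/tiny-mbart is assumed here as a tiny public test checkpoint, and any seq2seq model would do:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/tiny-mbart")
batch = tok(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")
# Per-call override: decoding starts from the ro_RO language-code token
# instead of whatever the config specifies.
out = model.generate(**batch, decoder_start_token_id=tok.lang_code_to_id["ro_RO"])
print(tok.batch_decode(out, skip_special_tokens=True))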
@@ -66,6 +74,13 @@ def run_generate():
     parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
     parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
+    parser.add_argument(
+        "--decoder_start_token_id",
+        type=int,
+        default=None,
+        required=False,
+        help="decoder_start_token_id (otherwise will look at config)",
+    )
     parser.add_argument(
         "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
     )
@@ -83,6 +98,7 @@ def run_generate():
         device=args.device,
         fp16=args.fp16,
         task=args.task,
+        decoder_start_token_id=args.decoder_start_token_id,
     )
     if args.reference_path is None:
         return
...
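With the plumbing above, the id can also be supplied on the command line. An illustrative invocation; the positional model/input/output arguments are assumed from run_eval.py's existing interface, and 250020 is just an example value:

python run_eval.py facebook/mbart-large-en-ro test.source output.txt \
    --task translation --bs 8 --decoder_start_token_id 250020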
@@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         return encoded_inputs

-    def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
-        return [self.decode(seq, **kwargs) for seq in sequences]
+    def batch_decode(
+        self, sequences: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+    ) -> List[str]:
+        """
+        Convert a list of lists of token ids into a list of strings by calling decode.
+
+        Args:
+            sequences: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+            skip_special_tokens: if set to True, will remove special tokens in the decoding.
+            clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+        """
+        return [
+            self.decode(
+                seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces
+            )
+            for seq in sequences
+        ]

     def decode(
         self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
...
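Finally, a usage sketch of the expanded batch_decode signature; the checkpoint name is illustrative, any pretrained tokenizer behaves the same:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
ids = tokenizer(["Hello world", "Goodbye"], padding=True).input_ids
# The kwargs are now explicit and forwarded to decode() for every sequence,
# so special tokens (including padding) are stripped from each output string.
texts = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(texts)  # ['Hello world', 'Goodbye']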