"...git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "61e7767ca67bf12e9fd1149e6322c74ab116a6c1"
Commit 7e7fc53d authored by LysandreJik's avatar LysandreJik Committed by Lysandre Debut
Browse files

Fixing run_glue example with RoBERTa

parent ab052806
...@@ -279,7 +279,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ...@@ -279,7 +279,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
sep_token=tokenizer.sep_token, sep_token=tokenizer.sep_token,
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token], pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
) )
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
......
...@@ -425,9 +425,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, ...@@ -425,9 +425,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# Account for [CLS], [SEP], [SEP] with "- 3" # Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else: else:
# Account for [CLS] and [SEP] with "- 2" # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
if len(tokens_a) > max_seq_length - 2: special_tokens_count = 3 if sep_token_extra else 2
tokens_a = tokens_a[:(max_seq_length - 2)] if len(tokens_a) > max_seq_length - special_tokens_count:
tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
# The convention in BERT is: # The convention in BERT is:
# (a) For sequence pairs: # (a) For sequence pairs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment