Unverified commit 9a21b506, authored by wlhgtc and committed by GitHub

Fix eval ref miss in Chinese WWM. (#8115)



* ADD: add whole word mask proxy for both eng and chinese

* MOD: adjust format

* MOD: reformat code

* MOD: update import

* MOD: fix bug

* MOD: add import

* MOD: fix bug

* MOD: decouple code and update readme

* MOD: reformat code

* Update examples/language-modeling/README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change wwm to whole_word_mask

* reformat code

* reformat

* format

* Code quality

* ADD: update chinese ref readme

* MOD: small changes

* MOD: small changes2

* update readme

* fix eval ref file miss bug

* format file

* MOD: move ref code to contrib

* MOD: add delimeter check

* reformat code

* refomat code

* Update examples/language-modeling/README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent fdf893c4

examples/contrib/run_chinese_ref.py
@@ -118,7 +118,7 @@ def main(args):
     # If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
     with open(args.file_name, "r", encoding="utf-8") as f:
         data = f.readlines()
+    data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]  # avoid delimiter like '\u2029'
     ltp_tokenizer = LTP(args.ltp)  # faster in GPU device
     bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
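
A minimal sketch of what the added filter line does, assuming `data` was just read with `readlines()` as above (the strings are made-up examples):

```python
raw_lines = ["今天天气很好\n", "\n", "\u2029\n", "明天去北京\n"]

# Same predicate as the added line: blank and whitespace-only lines (a bare
# U+2029 paragraph separator counts as whitespace) are dropped and the trailing
# newline is stripped, so every surviving line maps to exactly one entry in the
# generated reference file.
data = [line.strip() for line in raw_lines if len(line) > 0 and not line.isspace()]
print(data)  # ['今天天气很好', '明天去北京']
```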

examples/language-modeling/README.md
@@ -63,7 +63,7 @@ python run_language_modeling.py \
     --whole_word_mask
 ```

-For Chinese models, it's same with English model with only --mlm`. If using whole-word masking, we need to generate a reference files, case it's char level.
+For Chinese models, it's the same as for English models, with just `--mlm`. If using whole-word masking, we also need to generate a reference file, because the BERT tokenizer is character-level for Chinese.

 **Q :** Why ref file ?

@@ -76,15 +76,19 @@ So we need a ref file to tell model which pos of BERT original token should be a
 **A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT. It works well on so many Chines Task like CLUE (Chinese GLUE).
 They use LTP, so if we want to fine-tune their model, we need LTP.
+Now LTP only works well on `transformers==3.2.0`, so we don't add it to requirements.txt.
+You need to check out `transformers==3.2.0` to run `run_chinese_ref.py`; the code can be found in `examples/contrib`.

 ```bash
 export TRAIN_FILE=/path/to/dataset/wiki.train.raw
 export LTP_RESOURCE=/path/to/ltp/tokenizer
 export BERT_RESOURCE=/path/to/bert/tokenizer
 export SAVE_PATH=/path/to/data/ref.txt

-python chinese_ref.py \
+python examples/contrib/run_chinese_ref.py \
     --file_name=$TRAIN_FILE \
-    --ltp=$LTP_RESOURCE
+    --ltp=$LTP_RESOURCE \
     --bert=$BERT_RESOURCE \
     --save_path=$SAVE_PATH
 ```
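
A simplified, hypothetical illustration of what the reference file encodes: for each line of text, the positions of characters that are not the first character of their LTP word, i.e. the positions whole word masking should treat as `##` continuations. The real logic lives in `examples/contrib/run_chinese_ref.py` and additionally has to deal with special tokens and non-Chinese text; `sub_char_positions` below is only an illustrative helper, not part of that script.

```python
def sub_char_positions(ltp_words):
    """Return 0-based character positions that are NOT the first character of
    their LTP word; these are the positions a ref file needs to flag as '##'."""
    positions, idx = [], 0
    for word in ltp_words:
        for k in range(len(word)):
            if k > 0:
                positions.append(idx)
            idx += 1
    return positions


# LTP segments "我喜欢北京" into words, while the BERT tokenizer sees one token
# per character: ["我", "喜", "欢", "北", "京"].
print(sub_char_positions(["我", "喜欢", "北京"]))  # [2, 4] -> "欢" and "京"
```

In the generated ref file, each line is a JSON list of such positions; the dataset code below parses it with `json.loads` and attaches it to the matching example as `chinese_ref`.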

examples/language-modeling/run_language_modeling.py
@@ -103,9 +103,13 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
     )
-    chinese_ref_file: Optional[str] = field(
+    train_ref_file: Optional[str] = field(
         default=None,
-        metadata={"help": "An optional input ref data file for whole word mask in Chinees."},
+        metadata={"help": "An optional input train ref data file for whole word mask in Chinese."},
+    )
+    eval_ref_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input eval ref data file for whole word mask in Chinese."},
     )
     line_by_line: bool = field(
         default=False,
@@ -148,16 +152,16 @@ def get_dataset(
     evaluate: bool = False,
     cache_dir: Optional[str] = None,
 ):
-    def _dataset(file_path):
+    def _dataset(file_path, ref_path=None):
         if args.line_by_line:
-            if args.chinese_ref_file is not None:
+            if ref_path is not None:
                 if not args.whole_word_mask or not args.mlm:
                     raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask")
                 return LineByLineWithRefDataset(
                     tokenizer=tokenizer,
                     file_path=file_path,
                     block_size=args.block_size,
-                    ref_path=args.chinese_ref_file,
+                    ref_path=ref_path,
                 )
             return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
@@ -171,11 +175,11 @@ def get_dataset(
     )
     if evaluate:
-        return _dataset(args.eval_data_file)
+        return _dataset(args.eval_data_file, args.eval_ref_file)
     elif args.train_data_files:
         return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
     else:
-        return _dataset(args.train_data_file)
+        return _dataset(args.train_data_file, args.train_ref_file)

 def main():
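
The renamed and added dataclass fields surface as `--train_ref_file` and `--eval_ref_file` command-line flags, since the script parses its argument dataclasses with `HfArgumentParser`. A minimal sketch (the tiny `RefArguments` dataclass is only a stand-in for the script's `DataTrainingArguments`; the field names are the real ones from the diff):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class RefArguments:
    # Field names taken from the diff above; everything else is illustrative.
    train_ref_file: Optional[str] = field(default=None)
    eval_ref_file: Optional[str] = field(default=None)


parser = HfArgumentParser(RefArguments)
(ref_args,) = parser.parse_args_into_dataclasses(
    args=["--train_ref_file", "train_ref.txt", "--eval_ref_file", "eval_ref.txt"]
)
print(ref_args.train_ref_file, ref_args.eval_ref_file)  # train_ref.txt eval_ref.txt
```

With the fix, the evaluation dataset now receives its own reference file instead of none; both the train and eval paths still require `--whole_word_mask` and `--mlm`, as the `ValueError` above enforces.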

src/transformers/data/datasets/language_modeling.py
@@ -128,15 +128,17 @@ class LineByLineWithRefDataset(Dataset):
         logger.info("Creating features from dataset file at %s", file_path)
         logger.info("Use ref segment results at %s", ref_path)
         with open(file_path, encoding="utf-8") as f:
-            data = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
-        batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
-        self.examples = batch_encoding["input_ids"]
-        self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
+            data = f.readlines()  # use this method to avoid delimiter '\u2029' to split a line
+        data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
         # Get ref inf from file
         with open(ref_path, encoding="utf-8") as f:
             ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
         assert len(data) == len(ref)
+        batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
+        self.examples = batch_encoding["input_ids"]
+        self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
         n = len(self.examples)
         for i in range(n):
             self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
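
Why the switch from `splitlines()` to `readlines()` matters here: `str.splitlines()` also breaks on Unicode separators such as U+2029, so a text line that happens to contain one would be split in two, `len(data)` would no longer match `len(ref)`, and the assert above would fail. A small self-contained check (using `io.StringIO` in place of the actual file):

```python
import io

text = "这一行包含一个段落分隔符\u2029但它仍是同一行\n第二行\n"

# str.splitlines() treats U+2029 (PARAGRAPH SEPARATOR) as a line boundary ...
print(len(text.splitlines()))              # 3

# ... while readlines() only splits on '\n', keeping the line count aligned with
# the one-JSON-object-per-line reference file.
print(len(io.StringIO(text).readlines()))  # 2
```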