# BERT Training on the TF2 Framework

## Model Introduction

BERT (Bidirectional Encoder Representations from Transformers) is a pre-trained language representation model. Rather than pre-training with a traditional unidirectional language model, or a shallow concatenation of two unidirectional models, it uses a masked language model (MLM) objective, which lets it learn deep bidirectional language representations.

## Model Structure

Earlier pre-trained models were constrained by unidirectional language modeling (left-to-right or right-to-left), which limited their representational power to context from a single direction. BERT instead pre-trains with MLM and builds the model from deep bidirectional Transformer components (a unidirectional Transformer is usually called a Transformer decoder, where each token attends only to the tokens to its left; a bidirectional Transformer is called a Transformer encoder, where each token attends to all tokens), so the final representations fuse context from both directions.

## Model Download

[bert-base-uncased (used for MNLI classification)](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)

[bert-large-uncased (used for SQuAD question answering)](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip)

## Dataset Preparation

MNLI classification dataset: [MNLI](https://dl.fbaipublicfiles.com/glue/data/MNLI.zip). The full GLUE data can also be downloaded from https://pan.baidu.com/s/1tLd8opr08Nw5PzUBh7lXsQ (extraction code: fyvy); classification uses its MNLI subset.

SQuAD question-answering dataset: [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json), [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)

SQuAD v1.1 eval script: [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)

## Environment Setup

Running in Docker is recommended; a [SourceFind](https://www.sourcefind.cn/#/main-page) image is provided and can be pulled with docker pull:

```
docker pull image.sourcefind.cn:5000/dcu/admin/base/tensorflow:2.7.0-centos7.6-dtk-22.10.1-py37-latest
```

Alternatively, set up a Python virtual environment:

```
virtualenv -p python3 --system-site-packages venv_2
source venv_2/bin/activate
```

Install the Python dependencies:

```
pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
pip install tensorflow-2.7.0-cp36-cp36m-linux_x86_64.whl
pip install horovod-0.21.3-cp36-cp36m-linux_x86_64.whl
pip install apex-0.1-cp36-cp36m-linux_x86_64.whl
```

Set the environment variables:

```
module rm compiler/rocm/2.9
export ROCM_PATH=/public/home/hepj/job_env/apps/dtk-21.10.1
export HIP_PATH=${ROCM_PATH}/hip
export AMDGPU_TARGETS="gfx900;gfx906"
export PATH=${ROCM_PATH}/bin:${ROCM_PATH}/llvm/bin:${ROCM_PATH}/hcc/bin:${ROCM_PATH}/hip/bin:$PATH
```
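With the environment in place, it helps to confirm that this TensorFlow build actually sees the accelerators before launching any job. A minimal sketch (not part of the original scripts; the ROCm build of TensorFlow reports DCU cards as GPU devices):

```
import tensorflow as tf

# On the DTK/ROCm build, DCU cards show up as "GPU" physical devices.
gpus = tf.config.list_physical_devices("GPU")
print("Visible devices:", gpus)
assert gpus, "No devices visible -- re-check ROCM_PATH and the loaded modules"
```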
# MNLI Classification Test

## Data Conversion

TF2 reads data differently from TF1, so the dataset must first be converted to tf_record format:
```
python ../data/create_finetuning_data.py \
--input_data_dir=/public/home/hepj/data/MNLI \
--vocab_file=/public/home/hepj/model/tf2.7.0_Bert/pre_tf2x/vocab.txt \
--train_data_output_path=/public/home/hepj/model/tf2.7.0_Bert/MNLI/train.tf_record \
...
--fine_tuning_task_type=classification \
--max_seq_length=32 \
--classification_task_name=MNLI
# Parameter description
--input_data_dir            training data path
--vocab_file                vocab file path
--train_data_output_path    where the generated training data is written
--eval_data_output_path     where the generated evaluation data is written
--fine_tuning_task_type     fine-tuning task type
--do_lower_case             whether to lower-case the input text
--max_seq_length            maximum sequence length
--classification_task_name  classification task name
```
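To sanity-check the conversion, read a single record back from the generated file. A minimal sketch, with an illustrative shortened path (substitute your --train_data_output_path); the exact feature keys come from the classifier data pipeline:

```
import tensorflow as tf

# Illustrative path -- substitute the --train_data_output_path used above.
dataset = tf.data.TFRecordDataset("MNLI/train.tf_record")

for raw in dataset.take(1):
    example = tf.train.Example.FromString(raw.numpy())
    # Expect keys such as input_ids, input_mask, segment_ids and label_ids.
    print(sorted(example.features.feature.keys()))
```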
## Model Conversion

TF2.7.2 stores and reads models in a different format from TF1.15.0, and the officially published BERT checkpoints are generally TF1-based, so the model must be converted:

```
python3 tf2_encoder_checkpoint_converter.py \
--bert_config_file /public/home/hepj/model_source/uncased_L-12_H-768_A-12/bert_config.json \
--checkpoint_to_convert /public/home/hepj/model_source/uncased_L-12_H-768_A-12/bert_model.ckpt \
--converted_checkpoint_path pre_tf2x/

# Parameter description
--bert_config_file           BERT model config file
--checkpoint_to_convert      path of the checkpoint to convert
--converted_checkpoint_path  path for the converted checkpoint
```
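A converted checkpoint can be inspected without building the model. A minimal sketch, assuming the pre_tf2x/ output directory from the command above:

```
import tensorflow as tf

# pre_tf2x/ is the --converted_checkpoint_path used above.
for name, shape in tf.train.list_variables("pre_tf2x/")[:10]:
    print(name, shape)
```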
## Single-Card Run

bert_class.sh:
```
export HSA_FORCE_FINE_GRAIN_PCIE=1
export MIOPEN_FIND_MODE=3
export MIOPEN_ENABLE_LOGGING_CMD=1
export ROCBLAS_LAYER=3
module unload compiler/rocm/2.9
echo "MIOPEN_FIND_MODE=$MIOPEN_FIND_MODE"
lrank=$OMPI_COMM_WORLD_LOCAL_RANK
comm_rank=$OMPI_COMM_WORLD_RANK
comm_size=$OMPI_COMM_WORLD_SIZE
python3 run_classifier.py \
--mode=train_and_eval \
--input_meta_data_path=/public/home/hepj/model/tf2.7.0_Bert/MNLI/meta_data \
--train_data_path=/public/home/hepj/model/tf2.7.0_Bert/MNLI/train.tf_record \
--eval_data_path=/public/home/hepj/model/tf2.7.0_Bert/MNLI/eval.tf_record \
--bert_config_file=/public/home/hepj/model/tf2.7.0_Bert/pre_tf2x/bert_config.json \
--init_checkpoint=/public/home/hepj/model/tf2.7.0_Bert/pre_tf2x/bert_model.ckpt \
--train_batch_size=320 \
--eval_batch_size=32 \
--steps_per_loop=1000 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=/public/home/hepj/model/tf2/out1 \
--distribution_strategy=mirrored
#参数说明
--bert_config_file bert模型config文件
--checkpoint_to_convert 需要转换的模型路径
--converted_checkpoint_path 转换后模型路径
```
Run:

```
sh bert_class.sh
```

## Multi-Card Run

Data and model conversion are the same as for the single-card run. bert_class4.sh:
```
# --train_batch_size here is the global train_batch_size
# note: launching multi-card runs via mpirun currently has some issues
export HIP_VISIBLE_DEVICES=0,1,2,3
export HSA_FORCE_FINE_GRAIN_PCIE=1
export MIOPEN_FIND_MODE=3
module unload compiler/rocm/2.9
echo "MIOPEN_FIND_MODE=$MIOPEN_FIND_MODE"
lrank=$OMPI_COMM_WORLD_LOCAL_RANK
comm_rank=$OMPI_COMM_WORLD_RANK
comm_size=$OMPI_COMM_WORLD_SIZE
python3 run_classifier.py \
--mode=train_and_eval \
--input_meta_data_path=/public/home/hepj/model/tf2.7.0_Bert/MNLI/meta_data \
--train_data_path=/public/home/hepj/model/tf2.7.0_Bert/MNLI/train.tf_record \
--eval_data_path=/public/home/hepj/model/tf2.7.0_Bert/MNLI/eval.tf_record \
--bert_config_file=/public/home/hepj/model/tf2.7.0_Bert/pre_tf2x/bert_config.json \
--init_checkpoint=/public/home/hepj/model/tf2.7.0_Bert/pre_tf2x/bert_model.ckpt \
--train_batch_size=1280 \
--eval_batch_size=32 \
--steps_per_loop=10 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--num_gpus=4 \
--model_dir=/public/home/hepj/outdir/tf2/class4 \
--distribution_strategy=mirrored
```
Run:

```
sh bert_class4.sh
```
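As the comments in bert_class4.sh note, --train_batch_size is the global batch size, which MirroredStrategy splits evenly across replicas. A minimal sketch of the arithmetic, using the values from the script above:

```
# Global vs. per-replica batch size under tf.distribute.MirroredStrategy.
num_gpus = 4              # --num_gpus
global_batch_size = 1280  # --train_batch_size (global)
assert global_batch_size % num_gpus == 0, "global batch must divide evenly"
per_replica_batch = global_batch_size // num_gpus
print(per_replica_batch)  # 320 per card per step, matching the single-card setting
```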
# SQuAD 1.1 Question-Answering Test

## Data Conversion

TF2 reads data in tf_record format, so the dataset must first be converted:
```
python3 create_finetuning_data.py \
...
--fine_tuning_task_type=squad \
--do_lower_case=False \
--max_seq_length=384
# Parameter description
--squad_data_file         training file path
--vocab_file              vocab file path
--train_data_output_path  where the generated training data is written
--eval_data_output_path   where the generated evaluation data is written
--fine_tuning_task_type   fine-tuning task type
--do_lower_case           whether to lower-case the input text
--max_seq_length          maximum sequence length
```
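create_finetuning_data.py also writes a meta_data JSON file (--meta_data_file_path) that the training script later consumes through --input_meta_data_path. A quick, illustrative way to inspect it (path shortened; the exact keys depend on the task type):

```
import json

# Illustrative local path -- the file written to --meta_data_file_path.
with open("squad1.1/meta_data") as f:
    meta = json.load(f)

# For the squad task this records fields such as max_seq_length
# and the training set size.
print(json.dumps(meta, indent=2))
```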
## Model Conversion
```
python3 tf2_encoder_checkpoint_converter.py \
--bert_config_file /public/home/hepj/model/model_source/uncased_L-24_H-1024_A-16/bert_config.json \
--checkpoint_to_convert /public/home/hepj/model/model_source/uncased_L-24_H-1024_A-16/bert_model.ckpt \
--converted_checkpoint_path /public/home/hepj/model_source/bert-large-uncased-TF2/
```
## Single-Card Run

bert_squad.sh:
```
export HSA_FORCE_FINE_GRAIN_PCIE=1
export MIOPEN_FIND_MODE=3
export MIOPEN_ENABLE_LOGGING_CMD=1
export ROCBLAS_LAYER=3
module unload compiler/rocm/2.9
echo "MIOPEN_FIND_MODE=$MIOPEN_FIND_MODE"
lrank=$OMPI_COMM_WORLD_LOCAL_RANK
comm_rank=$OMPI_COMM_WORLD_RANK
comm_size=$OMPI_COMM_WORLD_SIZE
python3 run_squad_xuan.py \
--mode=train_and_eval \
--vocab_file=/public/home/hepj/model/model_source/uncased_L-24_H-1024_A-16/vocab.txt \
--bert_config_file=/public/home/hepj/model/model_source/uncased_L-24_H-1024_A-16/bert_config.json \
--input_meta_data_path=/public/home/hepj/model/tf2.7.0_Bert/squad1.1/meta_data \
--train_data_path=/public/home/hepj/model/tf2.7.0_Bert/squad1.1/train.tf_record \
--predict_file=/public/home/hepj/model/model_source/sq1.1/dev-v1.1.json \
--init_checkpoint=/public/home/hepj/model_source/bert-large-uncased-TF2/bert_model.ckpt \
--train_batch_size=4 \
--predict_batch_size=4 \
--learning_rate=2e-5 \
--log_steps=1 \
--num_gpus=1 \
--distribution_strategy=mirrored \
--model_dir=/public/home/hepj/model/tf2/squad1 \
--run_eagerly=False
# Parameter description
--mode                    run mode: train_and_eval, export_only, or predict
--vocab_file              vocab file path
--input_meta_data_path    meta data used for training and evaluation
--train_data_path         training data path
--eval_data_path          evaluation data path
--bert_config_file        BERT model config file
--init_checkpoint         checkpoint used for initialization
--train_batch_size        training batch size
--predict_file            prediction file path
--eval_batch_size         evaluation batch size
--steps_per_loop          logging interval in steps
--learning_rate           learning rate
--num_train_epochs        number of training epochs
--model_dir               directory where the model is saved
--distribution_strategy   distribution strategy
--num_gpus                number of GPUs to use
```
Run:
```
sh bert_squad.sh
```
## Multi-Card Run

Data and model conversion are the same as for the single-card run. bert_squad4.sh:
```
# --train_batch_size here is the global train_batch_size
# note: launching multi-card runs via mpirun currently has some issues
export HSA_FORCE_FINE_GRAIN_PCIE=1
export MIOPEN_FIND_MODE=3
module unload compiler/rocm/2.9
echo "MIOPEN_FIND_MODE=$MIOPEN_FIND_MODE"
export HIP_VISIBLE_DEVICES=0,1,2,3
python3 run_squad_xuan.py \
--mode=train_and_eval \
--vocab_file=/public/home/hepj/model/model_source/uncased_L-24_H-1024_A-16/vocab.txt \
--bert_config_file=/public/home/hepj/model/model_source/uncased_L-24_H-1024_A-16/bert_config.json \
--input_meta_data_path=/public/home/hepj/model/tf2.7.0_Bert/squad1.1/meta_data \
--train_data_path=/public/home/hepj/model/tf2.7.0_Bert/squad1.1/train.tf_record \
--predict_file=/public/home/hepj/model/model_source/sq1.1/dev-v1.1.json \
--init_checkpoint=/public/home/hepj/model_source/bert-large-uncased-TF2/bert_model.ckpt \
--train_batch_size=16 \
--predict_batch_size=4 \
--learning_rate=2e-5 \
--log_steps=1 \
--num_gpus=4 \
--distribution_strategy=mirrored \
--model_dir=/public/home/hepj/outdir/tf2/squad4 \
--run_eagerly=False
```
Run:
```
sh bert_squad4.sh
```
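Once prediction has run, the downloaded evaluate-v1.1.py can score the output against the dev set. A minimal sketch; predictions.json is an assumed name for the prediction output written under --model_dir:

```
import subprocess

# File names are illustrative: dev-v1.1.json is the downloaded dev set,
# predictions.json the assumed prediction output under --model_dir.
subprocess.run(
    ["python", "evaluate-v1.1.py", "dev-v1.1.json", "predictions.json"],
    check=True,
)
```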
## Model Accuracy

## Source Repository and Issue Feedback
https://developer.hpccube.com/codes/modelzoo/bert-tf2
## References

https://github.com/tensorflow/models/tree/v2.3.0/official/nlp

## create_finetuning_data.py

The data-generation script referenced in the conversion steps above:

```
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "fine_tuning_task_type", "classification",
    ["classification", "regression", "squad", "retrieval"],
    "The name of the BERT fine tuning task for which data "
    "will be generated.")

# BERT classification specific flags.
flags.DEFINE_string(
    "input_data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_enum("classification_task_name", "MNLI",
                  ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
                   "PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
                  "The name of the task to train BERT classifier. The "
                  "difference between XTREME-XNLI and XNLI is: 1. the format "
                  "of input tsv files; 2. the dev set for XTREME is English "
                  "only and for XNLI is all languages combined. Same for "
                  "PAWS-X.")

flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                  "The name of sentence retrieval task for scoring")

# XNLI task specific flag.
flags.DEFINE_string(
    "xnli_language", "en",
    "Language of training data for XNLI task. If the value is 'all', the data "
    "of all languages will be used for training.")

# PAWS-X task specific flag.
flags.DEFINE_string(
    "pawsx_language", "en",
    "Language of training data for PAWS-X task. If the value is 'all', the "
    "data of all languages will be used for training.")

# BERT Squad task specific flags.
flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file for generating training data for the BERT squad task.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool(
    "version_2_with_negative", False,
    "If true, the SQuAD examples contain some that do not have an answer.")

# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "train_data_output_path", None,
    "The path in which generated training input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "eval_data_output_path", None,
    "The path in which generated evaluation input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "test_data_output_path", None,
    "The path in which generated test input data will be written as tf"
    " records. If None, do not generate test data. Must be a pattern template"
    " as test_{}.tfrecords if processor has language specific test data.")

flags.DEFINE_string("meta_data_file_path", None,
                    "The path in which input meta data will be written.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")

flags.DEFINE_enum(
    "tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
    "Specifies the tokenizer implementation, i.e., whether to use word_piece "
    "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
    "while ALBERT uses sentence_piece tokenizer.")

flags.DEFINE_string("tfds_params", "",
                    "Comma-separated list of TFDS parameter assignments for "
                    "generic classification data import (for more details "
                    "see the TfdsProcessor class documentation).")

def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.classification_task_name
          or FLAGS.tfds_params)
  if FLAGS.tokenizer_impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params,
        process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    processors = {
        "cola": classifier_data_lib.ColaProcessor,
        "mnli": classifier_data_lib.MnliProcessor,
        "mrpc": classifier_data_lib.MrpcProcessor,
        "qnli": classifier_data_lib.QnliProcessor,
        "qqp": classifier_data_lib.QqpProcessor,
        "rte": classifier_data_lib.RteProcessor,
        "sst-2": classifier_data_lib.SstProcessor,
        "xnli": functools.partial(classifier_data_lib.XnliProcessor,
                                  language=FLAGS.xnli_language),
        "paws-x": functools.partial(classifier_data_lib.PawsxProcessor,
                                    language=FLAGS.pawsx_language),
        "xtreme-xnli": functools.partial(
            classifier_data_lib.XtremeXnliProcessor),
        "xtreme-paws-x": functools.partial(
            classifier_data_lib.XtremePawsxProcessor),
    }
    task_name = FLAGS.classification_task_name.lower()
    if task_name not in processors:
      raise ValueError("Task not found: %s" % task_name)
    processor = processors[task_name](process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        FLAGS.input_data_dir,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)

def generate_regression_dataset():
  """Generates regression dataset and returns input meta data."""
  if FLAGS.tokenizer_impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params,
        process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    raise ValueError("No data processor found for the given regression task.")

def generate_squad_dataset():
  """Generates squad training dataset and returns input meta data."""
  assert FLAGS.squad_data_file
  if FLAGS.tokenizer_impl == "word_piece":
    return squad_lib_wp.generate_tf_record_from_json_file(
        FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
        FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
        FLAGS.doc_stride, FLAGS.version_2_with_negative)
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    return squad_lib_sp.generate_tf_record_from_json_file(
        FLAGS.squad_data_file, FLAGS.sp_model_file,
        FLAGS.train_data_output_path, FLAGS.max_seq_length,
        FLAGS.do_lower_case, FLAGS.max_query_length, FLAGS.doc_stride,
        FLAGS.version_2_with_negative)

def generate_retrieval_dataset():
  """Generate retrieval test and dev dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
  if FLAGS.tokenizer_impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  processors = {
      "bucc": sentence_retrieval_lib.BuccProcessor,
      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
  }

  task_name = FLAGS.retrieval_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  processor = processors[task_name](process_text_fn=processor_text_fn)

  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor,
      FLAGS.input_data_dir,
      tokenizer,
      FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path,
      FLAGS.max_seq_length)

def main(_):
  if FLAGS.tokenizer_impl == "word_piece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "FLAG vocab_file for word-piece tokenizer is not specified.")
  else:
    assert FLAGS.tokenizer_impl == "sentence_piece"
    if not FLAGS.sp_model_file:
      raise ValueError(
          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")

  if FLAGS.fine_tuning_task_type != "retrieval":
    flags.mark_flag_as_required("train_data_output_path")

  if FLAGS.fine_tuning_task_type == "classification":
    input_meta_data = generate_classifier_dataset()
  elif FLAGS.fine_tuning_task_type == "regression":
    input_meta_data = generate_regression_dataset()
  elif FLAGS.fine_tuning_task_type == "retrieval":
    input_meta_data = generate_retrieval_dataset()
  else:
    input_meta_data = generate_squad_dataset()

  tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
    writer.write(json.dumps(input_meta_data, indent=4) + "\n")


if __name__ == "__main__":
  flags.mark_flag_as_required("meta_data_file_path")
  app.run(main)
```