Commit 392df446 authored by Rayyyyy

Modify training code and README

parent 0274039a
@@ -68,11 +68,9 @@ export HF_ENDPOINT=https://hf-mirror.com
```
## Dataset
The model is fine-tuned on a combination of multiple datasets totaling more than 1 billion sentence pairs. Because the full data is large, only the [stsbenchmark](https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/stsbenchmark.tsv.gz) and [Simple Wikipedia Version 1.0](https://cs.pomona.edu/~dkauchak/simplification/) datasets are used here for demonstration; both are provided under `datasets`. For the full data, see the Model card of the [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model.
The directory structure of the datasets is as follows:
```
├── datasets
...
```
@@ -86,21 +84,31 @@ export HF_ENDPOINT=https://hf-mirror.com
Inference data must be converted to txt format; see [gen_simple_wikipedia_v1.py](./gen_simple_wikipedia_v1.py), which generates `simple_wiki_pair.txt`.
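For orientation, here is a minimal sketch of reading the converted pair file. The two-column, tab-separated layout is an assumption on our part; treat the generated `simple_wiki_pair.txt` in the repo as the authoritative format.

```python
# Hedged sketch: assumes one tab-separated sentence pair per line.
pairs = []
with open("datasets/simple_wikipedia_v1/simple_wiki_pair.txt", encoding="utf8") as f:
    for line in f:
        parts = line.rstrip("\n").split("\t")
        if len(parts) == 2:  # skip malformed lines
            pairs.append((parts[0], parts[1]))
print(f"loaded {len(pairs)} sentence pairs")
```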
## Training
- **Training** uses [bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) as the default model
- **Fine-tuning** uses [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) as the default model
### Single node, multiple GPUs
- Training
```bash
bash train.sh
```
- Fine-tuning
```bash
bash finetune.sh
```
### Single node, single GPU
- Training
```bash
python training_stsbenchmark.py
```
- Fine-tuning
```bash
python training_stsbenchmark_continue_training.py
```
## Inference
1. Download a pretrained model from [pretrained models](https://www.sbert.net/docs/pretrained_models.html); the current default is the [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model.
2. Run the following command. The test data defaults to `./datasets/simple_wikipedia_v1/simple_wiki_pair.txt`; set `--data_path` to another file to test. For the expected file format, see [simple_wiki_pair.txt](./datasets/simple_wikipedia_v1/simple_wiki_pair.txt). A minimal Python sketch of the underlying comparison follows the command block.
```bash
...
```
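The exact command is collapsed in this diff view. Independent of the script's flags, the core of the comparison is encoding both sentences and scoring them with cosine similarity; a minimal sketch (the sentences are placeholders, and the model name is the README's default):

```python
from sentence_transformers import SentenceTransformer, util

# Minimal sketch of pairwise similarity; the repo's script reads pairs from --data_path instead.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb_a = model.encode("A man is playing a guitar.", convert_to_tensor=True)
emb_b = model.encode("Someone plays an instrument.", convert_to_tensor=True)
print(util.cos_sim(emb_a, emb_b).item())  # higher means more similar
```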
@@ -7,4 +7,4 @@ export USE_MIOPEN_BATCHNORM=1
echo "Training start ..."
torchrun --nproc_per_node=4 training_stsbenchmark_continue_training.py --train_batch_size 16 --num_epochs 5
@@ -3,7 +3,7 @@ modelCode=656
# Model name
modelName=sentence-bert_pytorch
# Model description
modelDescription=A modification of the pretrained BERT network that uses siamese and triplet network structures to derive semantically meaningful sentence embeddings, which can be compared using cosine similarity
# Application scenarios
appScenario=inference,training,NLP,education,cybersecurity,government
# Framework type
......
#!/bin/bash
echo "Export params ..."
export HIP_VISIBLE_DEVICES=0,1,2,3 # adjust to the GPU IDs and count used for training
export HSA_FORCE_FINE_GRAIN_PCIE=1
export USE_MIOPEN_BATCHNORM=1
echo "Training start ..."
torchrun --nproc_per_node=4 training_stsbenchmark.py --train_batch_size 16 --num_epochs 5
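For the single-GPU case from the README, the same environment setup applies with one visible device; for example (assuming GPU 0):

```bash
export HIP_VISIBLE_DEVICES=0
python training_stsbenchmark.py --train_batch_size 16 --num_epochs 5
```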
"""
This example trains BERT (or any other transformer model like RoBERTa, DistilBERT, etc.) on the STSbenchmark from scratch. It generates sentence embeddings
that can be compared using cosine similarity to measure semantic similarity.
"""
import math
import os
import gzip
import csv
import logging
import argparse
from datetime import datetime
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
#### params
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='datasets/stsbenchmark.tsv.gz', help='Path to the STSbenchmark tsv.gz file')
parser.add_argument('--train_batch_size', type=int, default=16)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--model_name_or_path', type=str, default="bert-base-uncased")
parser.add_argument('--save_root_path', type=str, default="output", help='Model output folder')
parser.add_argument('--lr', type=float, default=2e-05)
args = parser.parse_args()
# Check if the dataset exists. If not, download and extract it
sts_dataset_path = args.data_path
if not os.path.exists(sts_dataset_path):
util.http_get('https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
#You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = args.model_name_or_path
# Read the dataset
train_batch_size = args.train_batch_size
num_epochs = args.num_epochs
model_save_path = args.save_root_path + "/training_stsbenchmark_" + model_name.replace("/", "-") + '-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
word_embedding_model = models.Transformer(model_name)
# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
# Convert the dataset to a DataLoader ready for training
logging.info("Read STSbenchmark train dataset")
train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
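    # Columns relied on below: 'split' (train/dev/test), 'score' (0-5),
    # and 'sentence1'/'sentence2' holding the sentence pair.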
for row in reader:
score = float(row['score']) / 5.0 # Normalize score to range 0 ... 1
inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)
if row['split'] == 'dev':
dev_samples.append(inp_example)
elif row['split'] == 'test':
test_samples.append(inp_example)
else:
train_samples.append(inp_example)
train_dataset = SentencesDataset(train_samples, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)
logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')
# Configure the training
warmup_steps = math.ceil(len(train_dataset) * num_epochs / train_batch_size * 0.1)  # 10% of the training steps for warm-up
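# Worked example: the standard STSb train split has 5,749 pairs, so the
# defaults here (batch 16, 10 epochs) give ceil(5749 * 10 / 16 * 0.1) = 360 steps.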
logging.info("Warmup-steps: {}".format(warmup_steps))
# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          optimizer_params={'lr': args.lr},  # wire up --lr, which was parsed but unused
          output_path=model_save_path)
##############################################################################
#
# Load the stored model and evaluate its performance on STS benchmark dataset
#
##############################################################################
model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
test_evaluator(model, output_path=model_save_path)
@@ -20,7 +20,7 @@ parser.add_argument('--data_path', type=str, default='datasets/stsbenchmark.tsv.
parser.add_argument('--train_batch_size', type=int, default=16)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--model_name_or_path', type=str, default="all-MiniLM-L6-v2")
parser.add_argument('--save_root_path', type=str, default="output", help='Model output folder')
parser.add_argument('--lr', default=2e-05)
args = parser.parse_args()
@@ -33,13 +33,13 @@ if __name__ == "__main__":
if not os.path.exists(sts_dataset_path):
    util.http_get('https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
model_name = args.model_name_or_path
train_batch_size = args.train_batch_size
num_epochs = args.num_epochs
model_save_path = args.save_root_path + "/training_stsbenchmark_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Load a pre-trained sentence transformer model
model = SentenceTransformer(model_name, device='cuda')
# Convert the dataset to a DataLoader ready for training
logging.info("Read STSbenchmark train dataset")
......