Commit a5bc7a53 authored by Rayyyyy

Add STS benchmark dataset training.

parent 0fccd232
@@ -6,14 +6,14 @@
## Model Structure
<div align=center>
<img src="./doc/model.png" width=300 height=400/>
</div>
## Algorithm Principle
For each sentence pair, sentence A and sentence B are passed through the network to obtain embeddings u and v. The similarity of the two embeddings is computed with cosine similarity and compared against the gold similarity score. This allows the network to be fine-tuned and to recognize sentence similarity.
<div align=center>
<img src="./doc/infer.png" width=500 height=520/>
</div>
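The bi-encoder idea above can be sketched in a few lines with the sentence-transformers API. The checkpoint name and the gold score below are illustrative only, not values taken from this repository:

```python
from sentence_transformers import SentenceTransformer, util

# Encode sentence A and sentence B independently with the same network.
model = SentenceTransformer("all-MiniLM-L6-v2")  # example checkpoint
u = model.encode("A man is playing a guitar.", convert_to_tensor=True)
v = model.encode("Someone is playing an instrument.", convert_to_tensor=True)

# Compare the cosine similarity of u and v with the (normalized) gold score.
cos_sim = util.cos_sim(u, v).item()
gold = 4.0 / 5.0                       # illustrative gold similarity, scaled to [0, 1]
squared_error = (cos_sim - gold) ** 2  # CosineSimilarityLoss minimizes this kind of error
print(f"cos_sim={cos_sim:.3f}, squared_error={squared_error:.4f}")
```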
## Environment Setup
@@ -37,9 +37,9 @@ docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk2
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/sentence-bert_pytorch
pip install -e .
pip install -U huggingface_hub hf_transfer
export HF_ENDPOINT=https://hf-mirror.com
```
### Dockerfile (Method 2)
@@ -52,9 +52,9 @@ docker build --no-cache -t sbert:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/sentence-bert_pytorch
pip install -e .
pip install -U huggingface_hub hf_transfer
export HF_ENDPOINT=https://hf-mirror.com
```
### Anaconda (Method 3)
@@ -72,48 +72,49 @@ Tips: the dtk software stack, python, torch and other DCU-related tool versions above must strictly
```bash
cd /your_code_path/sentence-bert_pytorch
pip install -e .
pip install -U huggingface_hub hf_transfer
export HF_ENDPOINT=https://hf-mirror.com
```
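If the Hugging Face Hub is slow or unreachable, the mirror endpoint exported above can also be used from Python. A hedged sketch using huggingface_hub (the model id is just an example; the repository's scripts may download models on their own):

```python
import os

# Point huggingface_hub at the mirror before importing it; enable hf_transfer for faster downloads.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

from huggingface_hub import snapshot_download

# Pre-download the default inference model into the local Hugging Face cache.
local_path = snapshot_download("sentence-transformers/all-MiniLM-L6-v2")
print(local_path)
```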
## Dataset
The model is fine-tuned on a combination of multiple datasets with a total of more than 1 billion sentence pairs. Each dataset is sampled with a weighted probability, which is specified in detail in the data_config.json file.
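A hedged sketch of the weighted sampling described above; the exact schema of data_config.json is not documented here, so the name/weight fields below are assumptions for illustration only:

```python
import json
import random

# Assumed layout: a list of entries such as {"name": "dataset_a", "weight": 3}.
with open("data_config.json", "r", encoding="utf8") as f:
    configs = json.load(f)

names = [cfg["name"] for cfg in configs]
weights = [cfg["weight"] for cfg in configs]

# Draw the source dataset for each training batch proportionally to its weight.
for _ in range(3):
    print(random.choices(names, weights=weights, k=1)[0])
```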
Because the full data is large, only the [Simple Wikipedia Version 1.0](https://cs.pomona.edu/~dkauchak/simplification/) dataset is used here for demonstration; it is provided in datasets/simple_wikipedia_v1. For the complete data, see the Model card of the [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model.
The directory structure of the dataset is as follows:
```
├── datasets
│   ├── stsbenchmark.tsv.gz
│   ├── simple_wikipedia_v1
│   │   ├── simple_wiki_pair.txt    # generated
│   │   ├── wiki.simple
│   │   └── wiki.unsimplified
```
The inference data needs to be converted to txt format; see the [gen_simple_wikipedia_v1.py](./gen_simple_wikipedia_v1.py) file, which generates `simple_wiki_pair.txt`.
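A hedged sketch of what such a conversion can look like, assuming wiki.unsimplified and wiki.simple are line-aligned and that the output uses the same sentence1/sentence2 JSON fields that infer.py reads; the real gen_simple_wikipedia_v1.py may differ:

```python
import json

src = "datasets/simple_wikipedia_v1/wiki.unsimplified"
tgt = "datasets/simple_wikipedia_v1/wiki.simple"
out = "datasets/simple_wikipedia_v1/simple_wiki_pair.txt"

# Write one JSON object per line with the pair of aligned sentences.
with open(src, encoding="utf8") as f_src, open(tgt, encoding="utf8") as f_tgt, \
        open(out, "w", encoding="utf8") as f_out:
    for s1, s2 in zip(f_src, f_tgt):
        pair = {"sentence1": s1.strip(), "sentence2": s2.strip()}
        f_out.write(json.dumps(pair, ensure_ascii=False) + "\n")
```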
## Training
By default, fine-tuning starts from the pre-trained model [MiniLM-L6-H384-uncased](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased); for details of the pre-training procedure, see its model card.
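For reference, the usual sentence-transformers pattern for building a trainable sentence embedding model on top of such a pre-trained transformer looks roughly like this. This is a sketch, not the repository's script; finetune.py loads its model through the `--model_name_or_path` argument instead:

```python
from sentence_transformers import SentenceTransformer, models

# Wrap the pre-trained transformer and add a pooling layer to obtain sentence embeddings.
word_embedding_model = models.Transformer("nreimers/MiniLM-L6-H384-uncased", max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
```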
### Single machine, multiple cards
```bash
bash finetune.sh
```
### Single machine, single card
```bash
python finetune.py
```
## Inference
1. Download a pre-trained model from [pretrained models](https://www.sbert.net/docs/pretrained_models.html); the current default is the [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model.
2. Run the command below. The test data defaults to `./datasets/simple_wikipedia_v1/simple_wiki_pair.txt`; the `--data_path` argument can be changed to point to another file to be scored, and its content format should follow [simple_wiki_pair.txt](./datasets/simple_wikipedia_v1/simple_wiki_pair.txt).
```bash
python infer.py --data_path ./datasets/simple_wikipedia_v1/simple_wiki_pair.txt --model_name_or_path all-MiniLM-L6-v2
```
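The input file is plain JSON Lines with `sentence1` and `sentence2` fields, as read by infer.py; a minimal way to build such a file for your own pairs (`my_pairs.txt` is a hypothetical name, and the sentences are illustrative):

```python
import json

pairs = [
    {"sentence1": "A man is playing a guitar.", "sentence2": "Someone is playing an instrument."},
]

# infer.py reads one JSON object per line and adds a "score" field with the cosine similarity.
with open("my_pairs.txt", "w", encoding="utf8") as f:
    for pair in pairs:
        f.write(json.dumps(pair, ensure_ascii=False) + "\n")
```

Then run `python infer.py --data_path my_pairs.txt --model_name_or_path all-MiniLM-L6-v2`.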
## Result
...
@@ -102,7 +102,7 @@ for sentence, embedding in zip(sentences, sentence_embeddings):
    print("Embedding:", embedding)
    print("")
```
## Pre-Trained Models
We provide a large list of [Pretrained Models](https://www.sbert.net/docs/pretrained_models.html) for more than 100 languages. Some models are general purpose models, while others produce embeddings for specific use cases. Pre-trained models can be loaded by just passing the model name: `SentenceTransformer('model_name')`.
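A short usage example (the model name is one of the pretrained models listed above):

```python
from sentence_transformers import SentenceTransformer

# Load a pre-trained model by name and encode a few sentences into fixed-size vectors.
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode([
    "This framework generates embeddings for each input sentence.",
    "Sentences are passed as a list of strings.",
])
print(embeddings.shape)
```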
...
{"sentence1": "不能,这是属于个人所有的固定资产。", "sentence2": "不可以,这是个人固定资产,不能买卖。", "score": 0.96}
{"sentence1": "不可以,这属于个人固定资产,不能交易。", "sentence2": "不可以,这属于个人固定资产。", "score": 0.99}
{"sentence1": "活动前一周内是推荐的提交时间段。", "sentence2": "通常建议在活动开始前的一周内提交。", "score": 0.99}
{"sentence1": "请一直向参观者强调“不要拍照”。", "sentence2": "请提醒参观者“禁止携带相机拍照”。", "score": 0.85}
{"sentence1": "可以自己选购所需物资。", "sentence2": "可以自行选购,没有限制。", "score": 0.85}
\ No newline at end of file
doc/model.png: image replaced (72.7 KB → 23.9 KB)
import os
import math
import gzip
import csv
import logging
import argparse
from datetime import datetime
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
#### Just some code to print debug information to stdout
@@ -16,7 +16,7 @@ logging.basicConfig(
)
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='datasets/stsbenchmark.tsv.gz', help='Path to the STS benchmark tsv.gz file')
parser.add_argument('--train_batch_size', type=int, default=16)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--model_name_or_path', type=str, default="all-MiniLM-L6-v2")
@@ -28,10 +28,10 @@ args = parser.parse_args()
if __name__ == "__main__":
    sts_dataset_path = args.data_path
    # Check if dataset exists. If not, download and extract it
    if not os.path.exists(sts_dataset_path):
        util.http_get('https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
    model_name_or_path = args.model_name_or_path
    train_batch_size = args.train_batch_size
@@ -41,27 +41,28 @@ if __name__ == "__main__":
    # Load a pre-trained sentence transformer model
    model = SentenceTransformer(model_name_or_path, device='cuda')
    # Convert the dataset to a DataLoader ready for training
    logging.info("Read STSbenchmark train dataset")
    # Read the dataset
    train_samples = []
    dev_samples = []
    test_samples = []
    with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
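        # Each STS benchmark row provides at least the 'split', 'score', 'sentence1' and 'sentence2' columns used below.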
        for row in reader:
            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)
            if row['split'] == 'dev':
                dev_samples.append(inp_example)
            elif row['split'] == 'test':
                test_samples.append(inp_example)
            else:
                train_samples.append(inp_example)
logging.info("Dealing data end.") logging.info("Dealing data end.")
    train_dataset = SentencesDataset(train_samples, model)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
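    # CosineSimilarityLoss fine-tunes the model so that cos(u, v) for each pair moves toward its normalized gold score.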
    train_loss = losses.CosineSimilarityLoss(model=model)
    # Development set: Measure correlation between cosine score and gold labels
@@ -92,5 +93,5 @@ if __name__ == "__main__":
    ##############################################################################
    model = SentenceTransformer(model_save_path)
    test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
    test_evaluator(model, output_path=model_save_path)
@@ -7,7 +7,4 @@ export USE_MIOPEN_BATCHNORM=1
echo "Training start ..."
torchrun --nproc_per_node=4 finetune.py --train_batch_size 16 --num_epochs 5
@@ -39,8 +39,8 @@ if __name__ == "__main__":
        print('dealing with:', line.strip())
        json_info = json.loads(line)
        # Sentences are encoded by calling model.encode()
        label_emb = model.encode(json_info.get("sentence1"))
        pred_emb = model.encode(json_info.get("sentence2"))
        cos_sim = util.cos_sim(label_emb, pred_emb)
        json_info["score"] = cos_sim.item()
        print("Cosine-Similarity:", cos_sim.item())
...