# Reranker
## Usage
Unlike an embedding model, a reranker takes a question and a document as input and directly outputs a similarity score instead of an embedding.
You can obtain a relevance score by feeding a query and a passage to the reranker.
The reranker is optimized with a cross-entropy loss, so the relevance score is not bounded to any particular range.
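Because the raw scores are unbounded logits, it can be convenient to map them into (0, 1) when a normalized relevance value is needed, e.g. for thresholding. A minimal sketch of this common post-processing (the `sigmoid` helper is illustrative, not part of FlagEmbedding):

```python
import math

def sigmoid(score: float) -> float:
    """Map an unbounded reranker logit to a (0, 1) relevance value."""
    return 1.0 / (1.0 + math.exp(-score))

# A raw score of 0 maps to exactly 0.5; higher logits map monotonically higher.
print(sigmoid(0.0))                  # 0.5
print(sigmoid(4.0) > sigmoid(-2.0))  # True
```

The relative ordering of passages is unchanged by this mapping, so it only matters when absolute values are compared across queries or against a fixed threshold.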

### Using FlagEmbedding
1. Make sure the environment is set up; see [Environment Setup](../../README.md#环境配置).
2. Compute relevance scores (the more relevant the pair, the higher the score):
```python
from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'])
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)
```

### Using Hugging Face Transformers
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    print(scores)
```
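Either scoring path above can be used to re-rank retrieved candidates: score every (query, passage) pair, then sort the passages by descending score. A minimal sketch with the scoring function injected (the `rerank` helper and the toy scorer are illustrative, not part of either library; with FlagEmbedding you would pass `reranker.compute_score` as `score_fn`):

```python
from typing import Callable, List, Tuple

def rerank(query: str, passages: List[str],
           score_fn: Callable[[List[List[str]]], List[float]]) -> List[Tuple[str, float]]:
    """Score each (query, passage) pair and sort passages by descending relevance."""
    scores = score_fn([[query, p] for p in passages])
    return sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)

# Toy scorer (longer passage = higher score) stands in for a real model so the
# sketch runs without downloading weights.
toy_score = lambda pairs: [float(len(p)) for _, p in pairs]
ranked = rerank('what is panda?',
                ['hi', 'The giant panda is a bear species endemic to China.'],
                toy_score)
print(ranked[0][0])  # the passage with the highest score comes first
```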

## Fine-tuning
You can follow this [example](../../examples/reranker/) to fine-tune the reranker.

The reranker is initialized from [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) and trained on a mixed multilingual dataset:
- Chinese: 788,491 text pairs from [T2ranking](https://huggingface.co/datasets/THUIR/T2Ranking), [MMmarco](https://github.com/unicamp-dl/mMARCO), [dulreader](https://github.com/baidu/DuReader), [Cmedqa-v2](https://github.com/zhangsheng93/cMedQA2), and [nli-zh](https://huggingface.co/datasets/shibing624/nli_zh)
- English: 933,090 text pairs from [msmarco](https://huggingface.co/datasets/sentence-transformers/embedding-training-data), [nq](https://huggingface.co/datasets/sentence-transformers/embedding-training-data), [hotpotqa](https://huggingface.co/datasets/sentence-transformers/embedding-training-data), and [NLI](https://github.com/princeton-nlp/SimCSE)
- Others: 97,458 text pairs from [Mr.TyDi](https://github.com/castorini/mr.tydi) (covering Arabic, Bengali, English, Finnish, Indonesian, Japanese, Korean, Russian, Swahili, Telugu, and Thai)

To strengthen cross-lingual relevance, we constructed two cross-lingual retrieval datasets based on [MMarco](https://github.com/unicamp-dl/mMARCO).
Specifically, we sampled 100,000 English queries to retrieve Chinese passages, and 100,000 Chinese queries to retrieve English passages. The datasets are released at [Shitao/bge-reranker-data](https://huggingface.co/datasets/Shitao/bge-reranker-data).
Currently, the model mainly supports Chinese and English, and performance may degrade on other low-resource languages.

## Evaluation
You can evaluate the reranker with our [c-mteb script](../../C_MTEB#evaluate-reranker).

## Acknowledgement
Part of the code is developed based on [Reranker](https://github.com/luyug/Reranker).