run.py 991 Bytes
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# encoding=gbk
import os
import sys

sys.path.append('../')

from transformers import HfArgumentParser

from search_demo.tool import LocalDatasetLoader, BMVectorIndex, Agent
from search_demo.arguments import ModelArguments, DataArguments

if __name__ == "__main__":
    # Interactive question-answering loop.
    #
    # Command-line arguments are parsed into ModelArguments / DataArguments,
    # a dataset loader and an index are built from paths located under
    # ``data_args.data_path``, and each non-empty line typed by the user is
    # passed to ``Agent.answer``. An empty line (or end of input) stops the
    # loop.
    parser = HfArgumentParser((ModelArguments, DataArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()

    # All on-disk resources live under one root directory:
    #   <root>/dataset, <root>/emb/data.npy and <root>/index
    data_root = data_args.data_path
    loader = LocalDatasetLoader(
        data_path=os.path.join(data_root, 'dataset'),
        embedding_path=os.path.join(data_root, 'emb/data.npy'),
    )
    index = BMVectorIndex(
        model_path=model_args.model_name_or_path,
        bm_index_path=os.path.join(data_root, 'index'),
        data_loader=loader,
    )
    agent = Agent(index)

    while True:
        # NOTE(review): the prompt text looks mis-encoded when this file is
        # read as UTF-8; the file declares ``encoding=gbk`` -- confirm the
        # literal against the original bytes before changing it.
        try:
            question = input("ʣ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin was closed (Ctrl-D / exhausted pipe) or the user pressed
            # Ctrl-C at the prompt: leave the loop instead of dying with a
            # traceback. Only input() is guarded, so an interrupt or error
            # raised inside agent.answer() still propagates unchanged.
            print()  # finish the prompt line so the shell prompt starts clean
            break

        if not question:
            # A blank line is the explicit "quit" signal.
            break

        # NOTE(review): RANKING / TOP_N semantics are defined by Agent.answer
        # in search_demo.tool -- presumably candidate-pool size and number of
        # results shown; confirm there.
        agent.answer(question, RANKING=1000, TOP_N=5, verbose=True)