# predict.sh (628 bytes) - last committed by yuguo-Jack.
# Run predict.py for the recall model on GPU (default) or CPU.
#
# Usage:
#   sh predict.sh                          # original GPU run on card 3
#   DEVICE=cpu sh predict.sh               # single-process CPU run
#   GPUS=0,1 ROOT_DIR=ckpt sh predict.sh   # override card ids / checkpoint dir
#
# Environment overrides (defaults reproduce the original GPU command):
#   DEVICE          "gpu" (default) or "cpu"
#   GPUS            card ids for paddle.distributed.launch --gpus (default: 3)
#   ROOT_DIR        checkpoint directory (default: checkpoints)
#   PARAMS_PATH     weights file
#                   (default: $ROOT_DIR/model_12000/model_state.pdparams;
#                   an earlier CPU example used model_20000 - set this to
#                   pick a different checkpoint)
#   TEXT_PAIR_FILE  input pairs file (default: recall/test.csv)
#
# Kept POSIX-sh compatible (no [[ ]], arrays or pipefail) so it also runs
# under `sh`.
set -eu

device="${DEVICE:-gpu}"
gpus="${GPUS:-3}"
root_dir="${ROOT_DIR:-checkpoints}"
params_path="${PARAMS_PATH:-${root_dir}/model_12000/model_state.pdparams}"
text_pair_file="${TEXT_PAIR_FILE:-recall/test.csv}"

# Fail fast with a clear message rather than starting a (possibly
# distributed) job that dies while loading its inputs.
for required_file in "$params_path" "$text_pair_file"; do
    if [ ! -f "$required_file" ]; then
        printf 'predict.sh: file not found: %s\n' "$required_file" >&2
        exit 1
    fi
done

# Build the launcher prefix in the positional parameters (POSIX sh has no
# arrays). Only the launcher differs between devices; every model argument
# below is shared, so the GPU and CPU invocations cannot drift apart.
case "$device" in
    gpu)
        set -- python -u -m paddle.distributed.launch --gpus "$gpus" predict.py
        ;;
    cpu)
        set -- python predict.py
        ;;
    *)
        printf "predict.sh: DEVICE must be 'gpu' or 'cpu', got '%s'\n" "$device" >&2
        exit 2
        ;;
esac

"$@" \
    --device "$device" \
    --params_path "$params_path" \
    --output_emb_size 256 \
    --batch_size 128 \
    --max_seq_length 64 \
    --model_name_or_path rocketqa-zh-base-query-encoder \
    --text_pair_file "$text_pair_file"