# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import paddle

from paddlenlp import Taskflow


def parse_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, so
            existing callers that pass nothing behave as before.

    Returns:
        argparse.Namespace: with ``max_seq_len`` (int, default 128),
        ``batch_size`` (int, default 32) and ``device`` (one of
        ``"cpu"``, ``"gpu"``, ``"xpu"``; default ``"gpu"``).

    Raises:
        SystemExit: raised by ``argparse`` on invalid options or ``--help``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--max_seq_len",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU/CPU for prediction.",
    )
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        choices=["cpu", "gpu", "xpu"],
        help="The device to run prediction on; it must be cpu/gpu/xpu.",
    )

    return parser.parse_args(argv)


def do_predict(args):
    """Run the WordTag knowledge-mining Taskflow on two sample sentences.

    Selects the Paddle device named by ``args.device``, builds a
    ``knowledge_mining`` Taskflow from ``args.batch_size`` and
    ``args.max_seq_len``, feeds it two hard-coded Chinese sentences and
    prints whatever the Taskflow returns.

    Args:
        args: Object exposing ``device``, ``batch_size`` and ``max_seq_len``
            attributes (normally the namespace produced by ``parse_args``).
    """
    paddle.set_device(args.device)

    # Keyword options forwarded verbatim to the Taskflow constructor.
    taskflow_options = {
        "model": "wordtag",
        "batch_size": args.batch_size,
        "max_seq_length": args.max_seq_len,
        "linking": True,
    }
    tagger = Taskflow("knowledge_mining", **taskflow_options)

    # Fixed demo inputs; the Taskflow is called once with the whole list.
    samples = [
        "《孤女》是2010年九州出版社出版的小说，作者是余兼羽。",
        "热梅茶是一道以梅子为主要原料制作的茶饮",
    ]
    print(tagger(samples))


def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value`` lines.

    The lines are ordered by attribute name and framed by a header and a
    footer banner. Values are rendered with ``str()``.

    Args:
        args: Any object supported by ``vars()``, typically an
            ``argparse.Namespace``.
    """
    settings = vars(args)
    print("-----------  Configuration Arguments -----------")
    # Attribute names are unique, so ordering by name alone is sufficient.
    for name in sorted(settings):
        print(f"{name!s}: {settings[name]!s}")
    print("------------------------------------------------")


if __name__ == "__main__":
    # Script entry point: parse the CLI options, echo them, then run prediction.
    cli_args = parse_args()
    print_arguments(cli_args)
    do_predict(cli_args)
