import argparse
import json
import logging
import math
import os
import sys
from datetime import datetime

from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from torch.utils.data import DataLoader

# Configure root logging: INFO level and above, each record prefixed with a
# second-resolution timestamp, emitted through sentence-transformers'
# LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

parser = argparse.ArgumentParser(
    description="Fine-tune a SentenceTransformer with CosineSimilarityLoss on scored sentence pairs."
)
# --data_path is mandatory: without it args.data_path would be None and the
# dataset-path check would crash with a TypeError instead of argparse printing
# a usage error.
parser.add_argument(
    '--data_path',
    type=str,
    required=True,
    help='Input JSON-lines file; each line needs "labels", "predict" and "score" keys',
)
parser.add_argument('--train_batch_size', type=int, default=16, help='Training batch size')
parser.add_argument('--num_epochs', type=int, default=10, help='Number of training epochs')
parser.add_argument(
    '--model_name_or_path',
    type=str,
    default="all-MiniLM-L6-v2",
    help='Pre-trained model name or local model path',
)
parser.add_argument(
    '--model_save_path',
    type=str,
    # Timestamped default so that separate runs write to separate folders.
    default="output/training_sbert_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    help='Output folder',
)
args = parser.parse_args()

if __name__ == "__main__":

    sts_dataset_path = args.data_path
    # Fail fast, with a non-zero exit status, when the dataset file is missing.
    # Nothing is downloaded here; the file must already exist locally.
    # isfile (rather than exists) also rejects directories, which would pass
    # an existence check and then fail inside open() below.
    if not os.path.isfile(sts_dataset_path):
        print("datasets is not exists!!!!", file=sys.stderr)
        sys.exit(1)

    model_name_or_path = args.model_name_or_path
    train_batch_size = args.train_batch_size
    num_epochs = args.num_epochs
    model_save_path = args.model_save_path

    # Read the JSON-lines dataset. Each non-blank line must be an object with
    # "labels" and "predict" (the sentence pair) and "score" (the target).
    logging.info("Read STSbenchmark train dataset")
    train_samples = []
    dev_samples = []
    with open(sts_dataset_path, "r", encoding="utf8") as fIn:
        # Iterate the file lazily instead of readlines(), and skip blank lines
        # (e.g. an empty trailing line) that json.loads would reject.
        rows = (json.loads(line) for line in fIn if line.strip())
        for index, row in enumerate(rows, start=1):
            # The score is passed through unchanged; no normalization happens
            # here. NOTE(review): CosineSimilarityLoss compares it against a
            # cosine similarity, so the file presumably already stores scores
            # in that range (e.g. 0..1) -- confirm with the data producer.
            score = float(row["score"])
            inp_example = InputExample(texts=[row["labels"], row["predict"]], label=score)

            # Deterministic 80/20 split: every 5th record goes to the dev set.
            if index % 5 == 0:
                dev_samples.append(inp_example)
            else:
                train_samples.append(inp_example)
    logging.info("Dealing data end.")

    if not train_samples:
        logging.error("No training samples found in %s", sts_dataset_path)
        sys.exit(1)

    # Load the pre-trained model only after the data has been validated, so a
    # bad dataset fails before the model is initialised.
    # NOTE: the device is hard-coded to CUDA, so a CUDA device is required.
    model = SentenceTransformer(model_name_or_path, device='cuda')

    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
    train_loss = losses.CosineSimilarityLoss(model=model)

    # Development set: measure correlation between cosine score and gold
    # labels. With fewer than 5 records the dev split is empty and there is
    # nothing to correlate, so evaluation is skipped in that case.
    if dev_samples:
        logging.info("Read STSbenchmark dev dataset")
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
    else:
        logging.warning("Dev split is empty; training without evaluation")
        evaluator = None

    # Warm the learning rate up over the first 10% of all optimisation steps
    # (batches per epoch * number of epochs).
    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
    logging.info("Warmup-steps: {}".format(warmup_steps))

    print("Start training ...")
    # Train the model; fit() evaluates on the dev split every 1000 steps (when
    # an evaluator exists) and writes the model to model_save_path.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=num_epochs,
        evaluation_steps=1000,
        warmup_steps=warmup_steps,
        output_path=model_save_path,
    )
    logging.info("Finetune end")

    # Reload the model saved to model_save_path and score it once more.
    # NOTE: there is no separate test split -- this re-evaluates the same dev
    # samples, despite the "sts-test" evaluator name (kept because it names the
    # results file written under model_save_path).
    if dev_samples:
        model = SentenceTransformer(model_save_path)
        test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-test")
        test_result = test_evaluator(model, output_path=model_save_path)
        logging.info("Evaluation of reloaded model on dev split: %s", test_result)
