#!/usr/bin/env python
import argparse
import logging
import os
from ast import literal_eval

import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data.distributed
from mlperf_logging.mllog import constants

import seq2seq.data.config as config
import seq2seq.train.trainer as trainers
import seq2seq.utils as utils
from seq2seq.data.dataset import LazyParallelDataset, PreprocessedDataset
from seq2seq.data.dataset import TextDataset
from seq2seq.data.tokenizer import Tokenizer
from seq2seq.inference.inference import Translator
from seq2seq.models.gnmt import GNMT
from seq2seq.train.smoothing import LabelSmoothing
from seq2seq.utils import log_start, log_end, log_event, configure_logger
from seq2seq.train.lr_scheduler import WarmupMultiStepLR


def parse_args():
    """
    Parse commandline arguments.
    """
    def exclusive_group(group, name, default, help):
        destname = name.replace('-', '_')
        subgroup = group.add_mutually_exclusive_group(required=False)
        subgroup.add_argument(f'--{name}', dest=f'{destname}',
                              action='store_true',
                              help=f'{help} (use \'--no-{name}\' to disable)')
        subgroup.add_argument(f'--no-{name}', dest=f'{destname}',
                              action='store_false', help=argparse.SUPPRESS)
        subgroup.set_defaults(**{destname: default})

    parser = argparse.ArgumentParser(
        description='GNMT training',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # dataset
    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--dataset-dir', default='data/wmt16_de_en',
                         help='path to the directory with training/test data')
    dataset.add_argument('--max-size', default=None, type=int,
                         help='use at most MAX_SIZE elements from training \
                         dataset (useful for benchmarking), by default \
                         uses entire dataset')
    dataset.add_argument('--preproc-data-dir', default='/tmp/preprocessed',
                         help='path to the directory with preprocessed \
                         training/test data')
    exclusive_group(group=dataset, name='use-preproc-data', default=True,
                    help='use preprocessed dataset')

    # results
    results = parser.add_argument_group('results setup')
    results.add_argument('--results-dir', default='results',
                         help='path to directory with results, it will be \
                         automatically created if it does not exist')
    results.add_argument('--save', default='gnmt',
                         help='defines subdirectory within RESULTS_DIR for \
                         results from this training run')
    results.add_argument('--print-freq', default=10, type=int,
                         help='print log every PRINT_FREQ batches')

    # model
    model = parser.add_argument_group('model setup')
    model.add_argument('--hidden-size', default=1024, type=int,
                       help='model hidden size')
    model.add_argument('--num-layers', default=4, type=int,
                       help='number of RNN layers in encoder and in decoder')
    model.add_argument('--dropout', default=0.2, type=float,
                       help='dropout applied to input of RNN cells')

    exclusive_group(group=model, name='share-embedding', default=True,
                    help='use shared embeddings for encoder and decoder')

    model.add_argument('--smoothing', default=0.1, type=float,
                       help='label smoothing, if equal to zero model will use \
                       CrossEntropyLoss, if not zero model will be trained \
                       with label smoothing loss')

    # setup
    general = parser.add_argument_group('general setup')
    general.add_argument('--math', default='fp16',
                         choices=['fp32', 'fp16'],
                         help='arithmetic type')
    general.add_argument('--seed', default=None, type=int,
                         help='master seed for random number generators, if \
                         "seed" is undefined then the master seed will be \
                         sampled from random.SystemRandom()')
    general.add_argument('--prealloc-mode', default='always', type=str,
                         choices=['off', 'once', 'always'],
                         help='controls preallocation')

    exclusive_group(group=general, name='eval', default=False,
                    help='run validation and test after every epoch')
    exclusive_group(group=general, name='env', default=False,
                    help='print info about execution env')
    exclusive_group(group=general, name='cuda', default=True,
                    help='enables cuda')
    exclusive_group(group=general, name='cudnn', default=True,
                    help='enables cudnn')
    exclusive_group(group=general, name='log-all-ranks', default=True,
                    help='enables logging from all distributed ranks, if \
                    disabled then only logs from rank 0 are reported')
    exclusive_group(group=general, name='fused-attention', default=False,
                    help='enables fused attention')
    exclusive_group(group=general, name='fused-xentropy', default=True,
                    help='enables fused cross cross entropy with label \
                    smoothing')

    # training
    training = parser.add_argument_group('training setup')
    training.add_argument('--train-batch-size', default=128, type=int,
                          help='training batch size per worker')
    training.add_argument('--train-global-batch-size', default=None, type=int,
                          help='global training batch size, this argument \
                          does not have to be defined, if it is defined it \
                          will be used to automatically \
                          compute train_iter_size \
                          using the equation: train_iter_size = \
                          train_global_batch_size // (train_batch_size * \
                          world_size)')
    training.add_argument('--train-iter-size', metavar='N', default=1,
                          type=int,
                          help='training iter size, training loop will \
                          accumulate gradients over N iterations and execute \
                          optimizer every N steps')
    training.add_argument('--epochs', default=8, type=int,
                          help='max number of training epochs')

    training.add_argument('--grad-clip', default=5.0, type=float,
                          help='enables gradient clipping and sets maximum \
                          norm of gradients')
    training.add_argument('--max-length-train', default=50, type=int,
                          help='maximum sequence length for training \
                          (including special BOS and EOS tokens)')
    training.add_argument('--min-length-train', default=0, type=int,
                          help='minimum sequence length for training \
                          (including special BOS and EOS tokens)')
    training.add_argument('--train-loader-workers', default=1, type=int,
                          help='number of workers for training data loading')
    training.add_argument('--batching', default='bucketing', type=str,
                          choices=['random', 'sharding', 'bucketing'],
                          help='select batching algorithm')
    training.add_argument('--shard-size', default=80, type=int,
                          help='shard size for "sharding" batching algorithm, \
                          in multiples of global batch size')
    training.add_argument('--num-buckets', default=5, type=int,
                          help='number of buckets for "bucketing" batching \
                          algorithm')

    # optimizer
    optimizer = parser.add_argument_group('optimizer setup')
    optimizer.add_argument('--optimizer', type=str, default='Adam',
                           help='training optimizer')
    optimizer.add_argument('--lr', type=float, default=1.00e-3,
                           help='learning rate')
    optimizer.add_argument('--optimizer-extra', type=str,
                           default="{}",
                           help='extra options for the optimizer')

    # mixed precision loss scaling
    loss_scaling = parser.add_argument_group('mixed precision loss scaling \
                                             setup')
    loss_scaling.add_argument('--init-scale', type=float, default=1024,
                              help='initial loss scale')
    loss_scaling.add_argument('--upscale-interval', type=float, default=128,
                              help='loss upscaling interval')

    # scheduler
    scheduler = parser.add_argument_group('learning rate scheduler setup')
    scheduler.add_argument('--warmup-steps', type=str, default='200',
                           help='number of learning rate warmup iterations')
    scheduler.add_argument('--remain-steps', type=str, default='0.666',
                           help='starting iteration for learning rate decay')
    scheduler.add_argument('--decay-interval', type=str, default='None',
                           help='interval between learning rate decay steps')
    scheduler.add_argument('--decay-steps', type=int, default=4,
                           help='max number of learning rate decay steps')
    scheduler.add_argument('--decay-factor', type=float, default=0.5,
                           help='learning rate decay factor')

    # test
    test = parser.add_argument_group('test setup')
    test.add_argument('--test-batch-size', default=128, type=int,
                      help='batch size for test')
    test.add_argument('--max-length-test', default=150, type=int,
                      help='maximum sequence length for test \
                      (including special BOS and EOS tokens)')
    test.add_argument('--min-length-test', default=0, type=int,
                      help='minimum sequence length for test \
                      (including special BOS and EOS tokens)')
    test.add_argument('--beam-size', default=5, type=int,
                      help='beam size')
    test.add_argument('--len-norm-factor', default=0.6, type=float,
                      help='length normalization factor')
    test.add_argument('--cov-penalty-factor', default=0.1, type=float,
                      help='coverage penalty factor')
    test.add_argument('--len-norm-const', default=5.0, type=float,
                      help='length normalization constant')
    test.add_argument('--intra-epoch-eval', metavar='N', default=0, type=int,
                      help='evaluate within training epoch, this option will \
                      enable extra N equally spaced evaluations executed \
                      during each training epoch')
    test.add_argument('--test-loader-workers', default=0, type=int,
                      help='number of workers for test data loading')

    # checkpointing
    chkpt = parser.add_argument_group('checkpointing setup')
    chkpt.add_argument('--start-epoch', default=0, type=int,
                       help='manually set initial epoch counter')
    chkpt.add_argument('--resume', default=None, type=str, metavar='PATH',
                       help='resumes training from checkpoint from PATH')
    chkpt.add_argument('--save-all', action='store_true', default=False,
                       help='saves checkpoint after every epoch')
    chkpt.add_argument('--save-freq', default=5000, type=int,
                       help='save checkpoint every SAVE_FREQ batches')
    chkpt.add_argument('--keep-checkpoints', default=0, type=int,
                       help='keep only last KEEP_CHECKPOINTS checkpoints, \
                       affects only checkpoints controlled by --save-freq \
                       option')

    # benchmarking
    benchmark = parser.add_argument_group('benchmark setup')
    benchmark.add_argument('--target-bleu', default=24.0, type=float,
                           help='target accuracy, training will be stopped \
                           when the target is achieved')

    # distributed
    distributed = parser.add_argument_group('distributed setup')
    distributed.add_argument('--rank', default=0, type=int,
                             help='global rank of the process, do not set!')
    distributed.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0), type=int,
                             help='local rank of the process, do not set!')
    distributed.add_argument('--enable-apex-allreduce-overlap',
                             action='store_true', default=False,
                             help='enable overlap of allreduce communication \
                             with bprop')
    distributed.add_argument('--apex-num-allreduce-streams',
                             default=1, type=int,
                             help='num. allreduce streams')
    distributed.add_argument('--apex-message-size', default=1e7, type=int,
                             help='min. number of elements in communication \
                             bucket')

    # distributed weight update
    dwu_group = parser.add_argument_group('distributed weight update setup')
    dwu_group.add_argument('--distributed-weight-update', '--dwu', default=0, type=int, metavar='DWU',
                       help='select distributed weight update strategy')
    dwu_group.add_argument('--dwu-group-size', '--dwugs', default=0, type=int, metavar='DWUGS',
                       help='distributed weight update group size. If arg is 0, defaults to one node')
    dwu_group.add_argument('--dwu-num-blocks', '--dwunb', default=8, type=int, metavar='DWUNB',
                       help='number of blocks in dwu scheme')
    dwu_group.add_argument('--dwu-num-chunks', '--dwuchks', default=4, type=int,
                       help='number of chunks of each parameters block')
    dwu_group.add_argument('--dwu-num-rs-pg', '--dwurspg', default=2, type=int, metavar='DWURSPG',
                       help='number of reduction-scatter streams in dwu scheme')
    dwu_group.add_argument('--dwu-num-ar-pg', '--dwuarpg', default=4, type=int, metavar='DWUARPG',
                       help='number of all-reduce streams in dwu scheme')
    dwu_group.add_argument('--dwu-num-ag-pg', '--dwuagpg', default=2, type=int, metavar='DWUAGPG',
                       help='number of all-gather streams in dwu scheme')
    dwu_group.add_argument('--dwu-full-pipeline', action='store_true', 
                       help='whether to do full or partial pipeline')
    dwu_group.add_argument('--dwu-overlap-reductions', action='store_true',
                       help='whether to overlap reductions with backprop')
    dwu_group.add_argument('--dwu-grad-norm', action='store_true',
                       help='whether to compute L2 grad norm')
    dwu_group.add_argument('--dwu-e5m2-allgather', action='store_true',
                       help='whether to use e5m2 allgather')

    args = parser.parse_args()

    args.warmup_steps = literal_eval(args.warmup_steps)
    args.remain_steps = literal_eval(args.remain_steps)
    args.decay_interval = literal_eval(args.decay_interval)

    return args


def build_criterion(vocab_size, padding_idx, smoothing, fusion):
    if smoothing == 0.:
        logging.info(f'Building CrossEntropyLoss')
        loss_weight = torch.ones(vocab_size)
        loss_weight[padding_idx] = 0
        criterion = nn.CrossEntropyLoss(weight=loss_weight, size_average=False)
    else:
        logging.info(f'Building LabelSmoothingLoss (smoothing: {smoothing})')
        criterion = LabelSmoothing(padding_idx, smoothing, fusion=fusion)
    return criterion


def main():
    """
    Launches data-parallel multi-gpu training.
    """

    configure_logger(constants.GNMT)
    log_start(key=constants.INIT_START, log_all_ranks=True)

    args = parse_args()
    device = utils.set_device(args.cuda, args.local_rank)
    distributed = utils.init_distributed(args.cuda)

    # preinit and warmup streams/ groups for apex DDP communicators
    # distributed weight update doesn't require this
    allreduce_communicators=None
    if distributed and args.distributed_weight_update == 0 and \
            args.apex_num_allreduce_streams > 1:
        bucket_pgs = [torch.distributed.new_group() for _ in range(args.apex_num_allreduce_streams)]
        bucket_streams = [torch.cuda.Stream() for _ in range(args.apex_num_allreduce_streams)]
        for pg, stream in zip(bucket_pgs,bucket_streams):
            with torch.cuda.stream(stream):
                torch.distributed.all_reduce(torch.cuda.FloatTensor(1), group=pg)
        allreduce_communicators=(bucket_pgs,bucket_streams)

    args.rank = utils.get_rank()

    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    # create directory for results
    save_path = os.path.join(args.results_dir, args.save)
    args.save_path = save_path
    os.makedirs(save_path, exist_ok=True)

    # setup logging
    log_filename = f'log_rank_{utils.get_rank()}.log'
    utils.setup_logging(args.log_all_ranks,
                        os.path.join(save_path, log_filename))

    if args.env:
        utils.log_env_info()

    logging.info(f'Saving results to: {save_path}')
    logging.info(f'Run arguments: {args}')

    # additional argument check
    if args.math == 'fp32' and (args.fused_attention or args.fused_xentropy):
        logging.warn(f'Only support FP16 `--fused-attention` and '
            '`--fused-xentropy`, disabling them')
        args.fused_attention = args.fused_xentropy = False

    # automatically set train_iter_size based on train_global_batch_size,
    # world_size and per-worker train_batch_size
    if args.train_global_batch_size is not None:
        global_bs = args.train_global_batch_size
        bs = args.train_batch_size
        world_size = utils.get_world_size()
        assert global_bs % (bs * world_size) == 0
        args.train_iter_size = global_bs // (bs * world_size)
        logging.info(f'Global batch size was set in the config, '
                     f'Setting train_iter_size to {args.train_iter_size}')
    # setup L2 promotion
    #aiss 
    #if args.cuda:
    #    utils.l2_promote()

    worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed, args.epochs,
                                                      device)
    worker_seed = worker_seeds[args.rank]
    logging.info(f'Worker {args.rank} is using worker seed: {worker_seed}')
    torch.manual_seed(worker_seed)

    # build tokenizer
    # https://github.com/mlperf/policies/issues/201
    pad_vocab = utils.pad_vocabulary(args.math)
    tokenizer = Tokenizer(os.path.join(args.dataset_dir, config.VOCAB_FNAME),
                          pad_vocab)

    vocab_size = tokenizer.vocab_size

    # build GNMT model
    model_config = {'hidden_size': args.hidden_size,
                    'num_layers': args.num_layers,
                    'dropout': args.dropout, 'batch_first': False,
                    'share_embedding': args.share_embedding,
                    'fusion': args.fused_attention}
    model = GNMT(vocab_size=vocab_size, **model_config)
    logging.info(model)

    # define loss function (criterion) and optimizer
    criterion = build_criterion(vocab_size, config.PAD, args.smoothing,
                                args.fused_xentropy)

    opt_config = {'optimizer': args.optimizer, 'lr': args.lr}
    opt_config.update(literal_eval(args.optimizer_extra))
    logging.info(f'Training optimizer config: {opt_config}')

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info(f'Number of parameters: {num_parameters}')

    # create trainer
    save_info = {'model_config': model_config, 'config': args, 'tokenizer':
                 tokenizer.get_state()}
    loss_scaling = {'init_scale': args.init_scale, 'upscale_interval':
                    args.upscale_interval}
    trainer_options = dict(
        criterion=criterion,
        grad_clip=args.grad_clip,
        iter_size=args.train_iter_size,
        save_path=save_path,
        save_freq=args.save_freq,
        save_info=save_info,
        opt_config=opt_config,
        batch_first=model.batch_first,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        loss_scaling=loss_scaling,
        print_freq=args.print_freq,
        cuda=args.cuda,
        distributed=distributed,
        distributed_overlap_allreduce=args.enable_apex_allreduce_overlap,
        distributed_overlap_num_allreduce_streams=args.apex_num_allreduce_streams,
        distributed_overlap_allreduce_messagesize=args.apex_message_size,
        distributed_overlap_allreduce_communicators=allreduce_communicators,
        intra_epoch_eval=args.intra_epoch_eval,
        prealloc_mode=args.prealloc_mode)

    trainer_options['model'] = model
    #########aiss##########
    #import pdb
    #pdb.set_trace()
    trainer = trainers.Seq2SeqTrainer(args=args, **trainer_options)

    trainer.preallocate(args.train_batch_size, args.max_length_train,
                        training=True)

    log_end(key=constants.INIT_STOP, sync=False)
    log_start(key=constants.RUN_START, sync=True)
    utils.barrier()

    log_event(key=constants.MAX_SEQUENCE_LENGTH,
              value=args.max_length_train,
              metadata={'method': 'discard'})

    if args.use_preproc_data:
        train_data = PreprocessedDataset(
            min_len=args.min_length_train,
            max_len=args.max_length_train,
            vocab_size=tokenizer.vocab_size,
            )
        train_data.read_data(
            os.path.join(args.preproc_data_dir, 'training.bin'),
            tokenizer.vocab_size,
            )
        train_data.prepare()
    else:
        train_data = LazyParallelDataset(
            src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME),
            tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME),
            tokenizer=tokenizer,
            min_len=args.min_length_train,
            max_len=args.max_length_train,
            sort=False,
            max_size=args.max_size,
            )

    test_data = TextDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_test,
        max_len=args.max_length_test,
        sort=True)

    batching_opt = {'shard_size': args.shard_size,
                    'num_buckets': args.num_buckets}

    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.train_batch_size,
                                         seeds=shuffling_seeds,
                                         batch_first=model.batch_first,
                                         shuffle=True,
                                         batching=args.batching,
                                         batching_opt=batching_opt,
                                         num_workers=args.train_loader_workers)

    log_event(key=constants.GLOBAL_BATCH_SIZE,
              value=args.train_batch_size * utils.get_world_size(),
              sync=False)

    test_loader = test_data.get_loader(batch_size=args.test_batch_size,
                                       batch_first=model.batch_first,
                                       shuffle=False,
                                       num_workers=args.test_loader_workers)

    log_event(key=constants.TRAIN_SAMPLES,
              value=train_loader.sampler.num_samples, sync=False)
    log_event(key=constants.EVAL_SAMPLES,
              value=len(test_loader.dataset), sync=False)

    translator = Translator(model=model,
                            tokenizer=tokenizer,
                            loader=test_loader,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_length_test,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda,
                            print_freq=args.print_freq,
                            dataset_dir=args.dataset_dir,
                            target_bleu=args.target_bleu,
                            save_path=args.save_path)

    total_train_iters = len(train_loader) // args.train_iter_size * args.epochs

    scheduler_config = {'warmup_steps': args.warmup_steps,
                        'remain_steps': args.remain_steps,
                        'decay_interval': args.decay_interval,
                        'decay_steps': args.decay_steps,
                        'decay_factor': args.decay_factor}

    logging.info(f'Training LR schedule config: {scheduler_config}')
    scheduler = WarmupMultiStepLR(trainer.optimizer, total_train_iters,
                                  **scheduler_config)
    trainer.scheduler = scheduler
    trainer.translator = translator

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            logging.error(f'No checkpoint found at {args.resume}')

    # training loop
    break_training = False
    test_bleu = None
    for epoch in range(args.start_epoch, args.epochs):
        log_start(key=constants.BLOCK_START,
                  metadata={'first_epoch_num': epoch + 1,
                            'epoch_count': 1},
                  sync=False)
        log_start(key=constants.EPOCH_START,
                  metadata={'epoch_num': epoch + 1},
                  sync=False)

        logging.info(f'Starting epoch {epoch}')
        train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss, train_perf = trainer.optimize(train_loader)

        log_end(key=constants.EPOCH_STOP,
                metadata={'epoch_num': epoch + 1},
                sync=False)

        if args.eval:
            log_start(key=constants.EVAL_START,
                      metadata={'epoch_num': epoch + 1},
                      sync=False)
            test_bleu, break_training = translator.run(calc_bleu=True,
                                                       epoch=epoch)
            log_event(key=constants.EVAL_ACCURACY,
                      value=test_bleu / 100,
                      metadata={'epoch_num': epoch + 1},
                      sync=False)
            log_end(key=constants.EVAL_STOP,
                    metadata={'epoch_num': epoch + 1},
                    sync=False)

        acc_log = []
        acc_log += [f'Summary: Epoch: {epoch}']
        acc_log += [f'Training Loss: {train_loss:.4f}']
        if args.eval:
            acc_log += [f'Test BLEU: {test_bleu:.2f}']

        perf_log = []
        perf_log += [f'Performance: Epoch: {epoch}']
        perf_log += [f'Training: {train_perf:.0f} Tok/s']

        if args.rank == 0:
            logging.info('\t'.join(acc_log))
            logging.info('\t'.join(perf_log))

        logging.info(f'Finished epoch {epoch}')
        log_end(key=constants.BLOCK_STOP,
                metadata={'first_epoch_num': epoch + 1},
                sync=False)

        if break_training:
            break

    if args.use_preproc_data:
        train_data.finalize()

    status = 'success' if break_training else 'aborted'
    log_end(key=constants.RUN_STOP,
            metadata={'status': status},
            sync=False)


if __name__ == '__main__':
    main()
