# Copyright 2021 Max Planck Institute for Software Systems, and
# National University of Singapore
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import asyncio
import shlex
import itertools
import traceback
import typing as tp

from simbricks.experiment.experiment_environment import ExpEnv
from simbricks.experiment.experiment_output import ExpOutput
from simbricks.exectools import Executor, SimpleComponent
from simbricks.proxy import NetProxyConnecter, NetProxyListener, SimProxy
from simbricks.simulators import HostSim, NICSim, NetSim, PCIDevSim, Simulator
import simbricks.utils.graphlib as graphlib

class Experiment(object):
    """Describes a simulation experiment. Holds information about the simulators
    to run and parameters to configure the experiment."""
    name: str
    """This experiment's name. Can be used to filter multiple experiments to be
    run."""
    timeout: int
    """Timeout for experiment in seconds."""
    checkpoint = False
    """Whether to use checkpoints in experiment.

    Can for example be used to speed up booting a host simulator by first
    running in a less accurate mode. Before we then start the application we are
    interested in, a checkpoint is taken and the simulator shut down. Then, the
    simulator is restored in the accurate mode using this checkpoint."""
    no_simbricks = False
    """`True` - No simbricks adapters are used in the simulators."""

    def __init__(self, name: str):
        self.name = name
        self.hosts: tp.List[HostSim] = []
        self.pcidevs: tp.List[PCIDevSim] = []
        self.networks: tp.List[NetSim] = []
        self.metadata = {}

    @staticmethod
    def _check_unique_name(existing, sim, msg: str):
        """Raise `Exception(msg)` if any simulator in `existing` has the same
        `name` as `sim`."""
        if any(other.name == sim.name for other in existing):
            raise Exception(msg)

    def add_host(self, sim: HostSim):
        """Add a host simulator.

        Raises:
            Exception: if a host with the same name was already added.
        """
        self._check_unique_name(self.hosts, sim, 'Duplicate host name')
        self.hosts.append(sim)

    def add_nic(self, sim: NICSim):
        """Add a NIC simulator. NICs are stored together with the other PCI
        device simulators (see `add_pcidev`)."""
        self.add_pcidev(sim)

    def add_pcidev(self, sim: PCIDevSim):
        """Add a PCI device simulator.

        Raises:
            Exception: if a PCI device with the same name was already added.
        """
        self._check_unique_name(self.pcidevs, sim, 'Duplicate pcidev name')
        self.pcidevs.append(sim)

    def add_network(self, sim: NetSim):
        """Add a network simulator.

        Raises:
            Exception: if a network with the same name was already added.
        """
        self._check_unique_name(self.networks, sim, 'Duplicate net name')
        self.networks.append(sim)

    def all_simulators(self):
        """ All simulators used in experiment (hosts, PCI devices, then
        networks), as a lazy iterator. """
        return itertools.chain(self.hosts, self.pcidevs, self.networks)

    def resreq_mem(self):
        """Memory required to run all simulators used in this experiment."""
        return sum(s.resreq_mem() for s in self.all_simulators())

    def resreq_cores(self):
        """Number of Cores required to run all simulators used in this
        experiment."""
        return sum(s.resreq_cores() for s in self.all_simulators())

class DistributedExperiment(Experiment):
    """Simulation experiment whose simulators are spread over several hosts.

    Every simulator, proxies included, is expected to be mapped to a host ID in
    the range ``[0, num_hosts)`` through `assign_sim_host`."""
    num_hosts = 1
    """Number of hosts to use."""
    host_mapping: tp.Dict[Simulator, int]
    """Mapping from simulator to host ID."""
    proxies_listen: tp.List[NetProxyListener]
    proxies_connect: tp.List[NetProxyConnecter]

    def __init__(self, name: str, num_hosts: int):
        self.num_hosts = num_hosts
        # simulator -> host ID
        self.host_mapping = {}
        # proxies are kept apart by role: listening side vs. connecting side
        self.proxies_listen = []
        self.proxies_connect = []
        super().__init__(name)

    def add_proxy(self, proxy: SimProxy):
        """Register a proxy, filed by whether it listens or connects."""
        bucket = self.proxies_listen if proxy.listen else self.proxies_connect
        bucket.append(proxy)

    def all_simulators(self):
        """All simulators of the experiment, followed by listening and then
        connecting proxies."""
        proxy_lists = (self.proxies_listen, self.proxies_connect)
        return itertools.chain(super().all_simulators(), *proxy_lists)

    def assign_sim_host(self, sim, host):
        """Place `sim` on the host with ID `host` (0 <= host < num_hosts)."""
        assert 0 <= host < self.num_hosts
        self.host_mapping[sim] = host

    def all_sims_assigned(self):
        """Return True iff every simulator has been mapped to a host."""
        mapping = self.host_mapping
        return all(sim in mapping for sim in self.all_simulators())


class ExperimentBaseRunner(object):
    """Base class for runners that execute an experiment's simulators.

    Subclasses decide which executor each simulator runs on by implementing
    `sim_executor`. `prepare` builds and ships config tars and runs the
    simulators' preparation commands; `run` starts the simulators in
    dependency order, waits for the ones marked to wait on, and then tears
    everything down again, returning the collected `ExpOutput`.
    """

    def __init__(self, exp: Experiment, env: ExpEnv, verbose: bool):
        self.exp = exp
        self.env = env
        self.verbose = verbose
        self.out = ExpOutput(exp)
        # (simulator, component) for every simulator process started
        self.running: tp.List[tp.Tuple[Simulator, SimpleComponent]] = []
        # (executor, path) of sockets to remove during cleanup
        self.sockets = []
        # components whose termination is awaited in `wait_for_sims`
        self.wait_sims = []

    def sim_executor(self, sim: Simulator) -> Executor:
        """Return the executor that `sim` runs on. Must be overridden."""
        raise NotImplementedError("Please implement this method")

    @staticmethod
    async def _gather_all(aws):
        """Await all awaitables concurrently, then re-raise the first error.

        Unlike `asyncio.wait`, this accepts an empty list and plain coroutine
        objects (which `asyncio.wait` rejects since Python 3.11), and it does
        not silently drop exceptions raised by the awaitables. Every awaitable
        is allowed to finish before an exception is re-raised, so no work is
        left running in the background.
        """
        results = await asyncio.gather(*aws, return_exceptions=True)
        for res in results:
            if isinstance(res, BaseException):
                raise res

    @staticmethod
    async def _gather_best_effort(aws):
        """Await all awaitables concurrently, printing instead of raising
        errors, so one failing cleanup step does not abort the others."""
        results = await asyncio.gather(*aws, return_exceptions=True)
        for res in results:
            if isinstance(res, BaseException):
                traceback.print_exception(type(res), res, res.__traceback__)

    def sim_graph(self):
        """Map each simulator to the set of simulators it depends on
        (`dependencies()` plus `extra_deps`); used for topological start-up
        ordering in `run`."""
        return {
            sim: set(sim.dependencies() + sim.extra_deps)
            for sim in self.exp.all_simulators()
        }

    async def start_sim(self, sim: Simulator):
        """ Start a simulator and wait for it to be ready.

        Simulators whose `run_cmd` is None are dummies and are not launched.
        Otherwise the process is started on the simulator's executor, its
        sockets are registered for cleanup, and this returns once the sockets
        from `sockets_wait` exist and `start_delay()` seconds have passed.
        """
        name = sim.full_name()
        if self.verbose:
            print('%s: starting %s' % (self.exp.name, name))

        run_cmd = sim.run_cmd(self.env)
        if run_cmd is None:
            if self.verbose:
                print('%s: started dummy %s' % (self.exp.name, name))
            return

        # run simulator
        executor = self.sim_executor(sim)
        sc = executor.create_component(name,
                    shlex.split(run_cmd), verbose=self.verbose,
                    canfail=True)
        await sc.start()
        # register immediately so cleanup terminates the process even if a
        # later start-up step below fails
        self.running.append((sim, sc))

        # add sockets for cleanup
        for s in sim.sockets_cleanup(self.env):
            self.sockets.append((executor, s))

        # Wait till sockets exist
        wait_socks = sim.sockets_wait(self.env)
        if wait_socks:
            if self.verbose:
                print('%s: waiting for sockets %s' % (self.exp.name, name))

            await executor.await_files(wait_socks, verbose=self.verbose)

        # add time delay if required
        delay = sim.start_delay()
        if delay > 0:
            await asyncio.sleep(delay)

        if sim.wait_terminate():
            self.wait_sims.append(sc)

        if self.verbose:
            print('%s: started %s' % (self.exp.name, name))

    async def before_wait(self):
        """Hook called after all simulators started, before waiting for
        them to terminate. Does nothing by default."""
        pass

    async def before_cleanup(self):
        """Hook called at the start of cleanup. Does nothing by default."""
        pass

    async def after_cleanup(self):
        """Hook called at the end of cleanup. Does nothing by default."""
        pass

    async def prepare(self):
        """Prepare the experiment before running it.

        Builds each host's config tar and sends it via that host's executor,
        then runs every simulator's preparation commands. Each phase runs in
        parallel; if any step fails, its exception is re-raised once the
        phase has finished.
        """
        # generate config tars
        copies = []
        for host in self.exp.hosts:
            path = self.env.cfgtar_path(host)
            if self.verbose:
                print('preparing config tar:', path)
            host.node_config.make_tar(path)
            copies.append(self.sim_executor(host).send_file(path, self.verbose))
        await self._gather_all(copies)

        # prepare all simulators in parallel
        sims = []
        for sim in self.exp.all_simulators():
            prep_cmds = list(sim.prep_cmds(self.env))
            executor = self.sim_executor(sim)
            sims.append(executor.run_cmdlist('prepare_' + self.exp.name,
                prep_cmds, verbose=self.verbose))
        await self._gather_all(sims)

    async def wait_for_sims(self):
        """ Wait for simulators to terminate (the ones marked to wait on). """
        if self.verbose:
            print('%s: waiting for hosts to terminate' % self.exp.name)
        for sc in self.wait_sims:
            await sc.wait()

    async def _start_sims(self):
        """Start all simulators in dependency order.

        Every simulator whose dependencies are up is started in parallel; the
        next batch only begins once the current batch is fully started.
        """
        ts = graphlib.TopologicalSorter(self.sim_graph())
        ts.prepare()
        while ts.is_active():
            ready = list(ts.get_ready())
            await self._gather_all([self.start_sim(sim) for sim in ready])
            for sim in ready:
                ts.done(sim)

    async def _cleanup(self):
        """Shut down all started simulators, remove their sockets, and add
        them to the output. Individual failures are printed, not raised."""
        if self.verbose:
            print('%s: cleaning up' % self.exp.name)

        await self.before_cleanup()

        # "interrupt, terminate, kill" all processes
        await self._gather_best_effort(
            [sc.int_term_kill() for _, sc in self.running])

        # wait for all processes to terminate
        for _, sc in self.running:
            await sc.wait()

        # remove all sockets
        await self._gather_best_effort(
            [executor.rmtree(sock) for executor, sock in self.sockets])

        # add all simulator components to the output
        for sim, sc in self.running:
            self.out.add_sim(sim, sc)

        await self.after_cleanup()

    async def run(self):
        """Run the experiment and return its `ExpOutput`.

        Any error during start-up or while waiting, including cancellation,
        marks the output as failed and prints the traceback; cleanup always
        runs afterwards.
        """
        try:
            self.out.set_start()
            await self._start_sims()
            await self.before_wait()
            await self.wait_for_sims()
        except BaseException:
            # Deliberately as broad as a bare `except:` so that cancellation
            # also records the failure and still returns the output.
            self.out.set_failed()
            traceback.print_exc()
        finally:
            self.out.set_end()
            await self._cleanup()
        return self.out


class ExperimentSimpleRunner(ExperimentBaseRunner):
    """Experiment runner that places every simulator on one shared executor.

    `exec` is the executor used for all simulators; the remaining arguments
    are forwarded unchanged to `ExperimentBaseRunner`."""

    def __init__(self, exec: Executor, *args, **kwargs):
        self.exec = exec
        super().__init__(*args, **kwargs)

    def sim_executor(self, sim: Simulator) -> Executor:
        # The simulator is irrelevant here: there is only a single executor.
        return self.exec


class ExperimentDistributedRunner(ExperimentBaseRunner):
    """ Experiment runner distributing simulators over multiple executors.

    `execs` is indexed by host ID: each simulator runs on the executor at the
    index recorded for it in the experiment's `host_mapping` (as set by
    `DistributedExperiment.assign_sim_host`). The remaining arguments are
    forwarded unchanged to `ExperimentBaseRunner`. """
    def __init__(self, execs, *args, **kwargs):
        self.execs = execs
        super().__init__(*args, **kwargs)
        # every host ID the experiment may use needs an executor
        assert self.exp.num_hosts <= len(execs), (
            'experiment uses %s hosts but only %s executors were given'
            % (self.exp.num_hosts, len(execs)))

    def sim_executor(self, sim: Simulator) -> Executor:
        """Return the executor for the host ID `sim` is assigned to."""
        return self.execs[self.exp.host_mapping[sim]]

    async def prepare(self):
        """Set proxy IPs from their executors, then prepare as usual."""
        # make sure all simulators are assigned to an executor
        assert self.exp.all_sims_assigned(), \
            'not all simulators are assigned to a host'

        # set IP addresses for proxies based on assigned executors
        for proxy in itertools.chain(
                self.exp.proxies_listen, self.exp.proxies_connect):
            proxy.ip = self.sim_executor(proxy).ip

        await super().prepare()
