import os
import tempfile
from typing import Type, Iterable
from contextlib import contextmanager
from shutil import copytree

import torch
import torchaudio
import pytest

# Directory containing this file, computed from the symlink-resolved path of
# ``__file__``; the ``assets`` folder is looked up relative to it.
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
# Candidate audio backends, read from torchaudio's private ``_backend`` module.
# Entries are handed to ``torchaudio.set_audio_backend`` by the helpers below.
BACKENDS = torchaudio._backend._audio_backends


def get_asset_path(*paths):
    """Build the full path of a test asset.

    Inputs:
        *paths (str): path components below the ``assets`` directory that
            sits next to this file

    Returns:
        str: the ``assets`` directory joined with ``paths``; with no
        components this is the ``assets`` directory itself
    """
    assets_root = os.path.join(_TEST_DIR_PATH, 'assets')
    return os.path.join(assets_root, *paths)


def create_temp_assets_dir():
    """
    Creates a temporary directory and copies the whole test ``assets`` tree
    into an ``assets`` sub-folder of it. The original assets are left in place.

    Returns a Tuple[string, TemporaryDirectory]: the path of the temporary
    directory and the ``TemporaryDirectory`` object that owns it. The caller
    receives the object so that it controls the directory's lifetime.
    """
    holder = tempfile.TemporaryDirectory()
    source = os.path.join(_TEST_DIR_PATH, "assets")
    destination = os.path.join(holder.name, "assets")
    copytree(source, destination)
    return holder.name, holder


def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
    """ Generates a deterministic pseudo-random tensor given a seed and size
    https://en.wikipedia.org/wiki/Linear_congruential_generator
    X_{n + 1} = (a * X_n + c) % m
    Using Borland C/C++ values

    The first element is derived from ``seed`` (the seed itself is not
    emitted); every following element is derived from the previous one.
    Each generator state is divided by ``m``, so values are nominally in
    [0, 1). The result is a float32 tensor.

    Inputs:
        seed (int): an int
        size (Tuple[int]): the size of the output tensor; may contain zero
            dimensions, in which case an empty tensor of that size is returned
        a (int): the multiplier constant to the generator
        c (int): the additive constant to the generator
        m (int): the modulus constant to the generator

    Returns:
        torch.Tensor: float32 tensor of shape ``size``
    """
    num_elements = 1
    for dim in size:
        num_elements *= dim

    # Produce exactly ``num_elements`` states. Seeding the list with one value
    # up front would yield one element too many when a dimension is zero and
    # make the final ``view`` fail.
    values = []
    state = seed
    for _ in range(num_elements):
        state = (a * state + c) % m
        values.append(state)

    return torch.tensor(values).float().view(size) / m


@contextmanager
def AudioBackendScope(new_backend):
    """Context manager that temporarily switches the torchaudio audio backend.

    On entry the currently selected backend is remembered and ``new_backend``
    is selected; on exit the remembered backend is selected again.

    Inputs:
        new_backend: value passed to ``torchaudio.set_audio_backend``
    """
    previous_backend = torchaudio.get_audio_backend()
    try:
        # The switch is inside the ``try`` so the restore below also runs
        # when selecting ``new_backend`` itself raises.
        torchaudio.set_audio_backend(new_backend)
        yield
    finally:
        # Always restore, including when the ``with`` body raised.
        torchaudio.set_audio_backend(previous_backend)


def filter_backends_with_mp3(backends):
    """Return the entries of ``backends`` that can load the mp3 test asset.

    Each backend is probed by selecting it and loading a bundled mp3 file.
    A backend is dropped when that attempt raises ``RuntimeError`` or
    ``ImportError``; any other exception propagates to the caller.

    Inputs:
        backends (Iterable): backends to probe, in order

    Returns:
        list: the backends that passed the probe, in their original order
    """
    mp3_path = get_asset_path('steam-train-whistle-daniel_simon.mp3')

    def can_load_mp3(backend):
        # The whole scope (switch, load, restore) is guarded, so a failure
        # in any of the three steps marks the backend as unsupported.
        try:
            with AudioBackendScope(backend):
                torchaudio.load(mp3_path)
        except (RuntimeError, ImportError):
            return False
        else:
            return True

    supported = []
    for candidate in backends:
        if can_load_mp3(candidate):
            supported.append(candidate)
    return supported


# Subset of BACKENDS that managed to load the mp3 test asset. Computed once,
# when this module is imported, by actually attempting the load per backend.
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)


class TestBaseMixin:
    """Base for test classes that are specialised per device and dtype.

    Both attributes are placeholders here; ``define_test_suite`` creates
    subclasses in which they are set to concrete values.
    """
    # Overridden with a ``torch.device`` in generated subclasses.
    device = None
    # Overridden with a torch dtype in generated subclasses.
    dtype = None


def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
    """Create a subclass of ``testbase`` bound to one dtype and one device.

    The generated class is named ``Test<Base>_<DEVICE>_<Dtype>`` and carries
    ``dtype`` (the matching ``torch`` dtype) and ``device`` (a
    ``torch.device``) as class attributes. CUDA variants are wrapped in a
    pytest ``skipif`` mark so they are skipped when CUDA is not available.

    Inputs:
        testbase (Type[TestBaseMixin]): class to derive from
        dtype (str): ``'float32'`` or ``'float64'``
        device (str): ``'cpu'`` or ``'cuda'``

    Returns:
        the generated test class

    Raises:
        NotImplementedError: if ``dtype`` or ``device`` is not one of the
            supported values
    """
    if dtype not in ('float32', 'float64'):
        raise NotImplementedError(f'Unexpected dtype: {dtype}')
    if device not in ('cpu', 'cuda'):
        raise NotImplementedError(f'Unexpected device: {device}')

    suite_name = f'Test{testbase.__name__}_{device.upper()}_{dtype.capitalize()}'
    suite_attrs = dict(dtype=getattr(torch, dtype), device=torch.device(device))
    generated = type(suite_name, (testbase,), suite_attrs)

    if device != 'cuda':
        return generated

    # Only CUDA variants need the availability guard.
    skip_without_cuda = pytest.mark.skipif(
        not torch.cuda.is_available(), reason='CUDA not available')
    return skip_without_cuda(generated)


def define_test_suites(
        scope: dict,
        testbases: Iterable[Type[TestBaseMixin]],
        dtypes: Iterable[str] = ('float32', 'float64'),
        devices: Iterable[str] = ('cpu', 'cuda'),
):
    """Generate a test class for every base/device/dtype combination.

    Each generated class is stored in ``scope`` under its own ``__name__``.

    Inputs:
        scope (dict): mapping that receives the generated classes
        testbases (Iterable[Type[TestBaseMixin]]): base classes to specialise
        dtypes (Iterable[str]): dtype names forwarded to ``define_test_suite``
        devices (Iterable[str]): device names forwarded to
            ``define_test_suite``
    """
    # Lazy generator with the nesting order base -> device -> dtype, so each
    # class is created right before it is registered.
    generated = (
        define_test_suite(testbase, dtype, device)
        for testbase in testbases
        for device in devices
        for dtype in dtypes
    )
    for testsuite in generated:
        scope[testsuite.__name__] = testsuite
