import pytest
import torch
from torch import Tensor
from torch_complex import ComplexTensor

from espnet2.enh.separator.dptnet_separator import DPTNetSeparator


@pytest.mark.parametrize("input_dim", [8])
@pytest.mark.parametrize("post_enc_relu", [True, False])
@pytest.mark.parametrize("rnn_type", ["lstm", "gru"])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("num_spk", [1, 2])
@pytest.mark.parametrize("unit", [8])
@pytest.mark.parametrize("att_heads", [4])
@pytest.mark.parametrize("dropout", [0.2])
@pytest.mark.parametrize("activation", ["relu"])
@pytest.mark.parametrize("norm_type", ["gLN"])
@pytest.mark.parametrize("layer", [1, 3])
@pytest.mark.parametrize("segment_size", [2, 4])
@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"])
def test_dptnet_separator_forward_backward_complex(
    input_dim,
    post_enc_relu,
    rnn_type,
    bidirectional,
    num_spk,
    unit,
    att_heads,
    dropout,
    activation,
    norm_type,
    layer,
    segment_size,
    nonlinear,
):
    """Run a training-mode forward/backward pass on a ComplexTensor batch.

    For every parametrized configuration, the separator receives a
    ComplexTensor built from two ``(2, 10, input_dim)`` random tensors with
    lengths ``[10, 8]``. It must return a list of ``num_spk`` outputs whose
    first element is a ComplexTensor, and a scalar loss built from that first
    output must be backpropagatable.
    """
    separator_kwargs = dict(
        input_dim=input_dim,
        post_enc_relu=post_enc_relu,
        rnn_type=rnn_type,
        bidirectional=bidirectional,
        num_spk=num_spk,
        unit=unit,
        att_heads=att_heads,
        dropout=dropout,
        activation=activation,
        norm_type=norm_type,
        layer=layer,
        segment_size=segment_size,
        nonlinear=nonlinear,
    )
    separator = DPTNetSeparator(**separator_kwargs)
    separator.train()

    # Arguments are evaluated left to right, so the real part is drawn first.
    spectrum = ComplexTensor(
        torch.rand(2, 10, input_dim),
        torch.rand(2, 10, input_dim),
    )
    batch_lengths = torch.tensor([10, 8], dtype=torch.long)

    outputs, _, _ = separator(spectrum, ilens=batch_lengths)

    first = outputs[0]
    assert isinstance(first, ComplexTensor)
    assert len(outputs) == num_spk

    # Only the first speaker's output feeds the loss.
    loss = first.abs().mean()
    loss.backward()


@pytest.mark.parametrize("input_dim", [8])
@pytest.mark.parametrize("post_enc_relu", [True, False])
@pytest.mark.parametrize("rnn_type", ["lstm", "gru"])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("num_spk", [1, 2])
@pytest.mark.parametrize("unit", [8])
@pytest.mark.parametrize("att_heads", [4])
@pytest.mark.parametrize("dropout", [0.2])
@pytest.mark.parametrize("activation", ["relu"])
@pytest.mark.parametrize("norm_type", ["gLN"])
@pytest.mark.parametrize("layer", [1, 3])
@pytest.mark.parametrize("segment_size", [2, 4])
@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"])
def test_dptnet_separator_forward_backward_real(
    input_dim,
    post_enc_relu,
    rnn_type,
    bidirectional,
    num_spk,
    unit,
    att_heads,
    dropout,
    activation,
    norm_type,
    layer,
    segment_size,
    nonlinear,
):
    """Run a training-mode forward/backward pass on a real-valued batch.

    For every parametrized configuration, the separator receives a random
    ``(2, 10, input_dim)`` torch tensor with lengths ``[10, 8]``. It must
    return a list of ``num_spk`` outputs whose first element is a plain
    ``torch.Tensor``, and a scalar loss built from that first output must be
    backpropagatable.
    """
    separator_kwargs = dict(
        input_dim=input_dim,
        post_enc_relu=post_enc_relu,
        rnn_type=rnn_type,
        bidirectional=bidirectional,
        num_spk=num_spk,
        unit=unit,
        att_heads=att_heads,
        dropout=dropout,
        activation=activation,
        norm_type=norm_type,
        layer=layer,
        segment_size=segment_size,
        nonlinear=nonlinear,
    )
    separator = DPTNetSeparator(**separator_kwargs)
    separator.train()

    features = torch.rand(2, 10, input_dim)
    batch_lengths = torch.tensor([10, 8], dtype=torch.long)

    outputs, _, _ = separator(features, ilens=batch_lengths)

    first = outputs[0]
    assert isinstance(first, Tensor)
    assert len(outputs) == num_spk

    # Only the first speaker's output feeds the loss.
    loss = first.abs().mean()
    loss.backward()


def test_dptnet_separator_invalid_args():
    """The constructor must reject invalid configurations.

    - An unknown ``nonlinear`` name (``"fff"``) raises ``ValueError``.
    - ``input_dim=10`` combined with ``att_heads=4`` raises ``AssertionError``.
      NOTE(review): presumably because the attention dimension must be
      divisible by the head count; confirm against DPTNetSeparator.
    """
    # Arguments common to both invalid constructions.
    common = dict(
        rnn_type="rnn",
        num_spk=2,
        unit=10,
        dropout=0.1,
        layer=2,
        segment_size=2,
    )

    with pytest.raises(ValueError):
        DPTNetSeparator(input_dim=8, nonlinear="fff", **common)

    with pytest.raises(AssertionError):
        DPTNetSeparator(input_dim=10, att_heads=4, nonlinear="relu", **common)


def test_dptnet_separator_output():
    """Check the output containers, shapes and per-speaker masks.

    For each speaker count in ``{1, 2}``, an eval-mode separator fed a
    ``(2, 10, 8)`` input with lengths ``[10, 8]`` must return:

    - a list of exactly ``num_spk`` outputs, each shaped like the input, and
    - a dict holding a ``mask_spk{n}`` entry for every speaker ``n``, whose
      shape matches that speaker's output.
    """
    x = torch.rand(2, 10, 8)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = DPTNetSeparator(
            input_dim=8,
            rnn_type="rnn",
            # Use the loop's speaker count so both the 1- and 2-speaker
            # configurations are actually built and checked.
            num_spk=num_spk,
            unit=10,
            dropout=0.1,
            layer=2,
            segment_size=2,
            nonlinear="relu",
        )
        model.eval()
        specs, _, others = model(x, x_lens)
        assert isinstance(specs, list)
        assert isinstance(others, dict)
        assert len(specs) == num_spk
        for n, spec in enumerate(specs, start=1):
            mask_key = f"mask_spk{n}"
            assert x.shape == spec.shape
            assert mask_key in others
            assert spec.shape == others[mask_key].shape
