"vscode:/vscode.git/clone" did not exist on "b1f0fc1c0b586c9e9e40c3f7ae96ec46efe1d666"
test_custom_stft.py 2.67 KB
Newer Older
guobj's avatar
init  
guobj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import numpy as np
import pytest
from kokoro.custom_stft import CustomSTFT
from kokoro.istftnet import TorchSTFT
import torch.nn.functional as F


@pytest.fixture
def sample_audio():
    """Return a one-second, 440 Hz, unit-amplitude sine wave sampled at 16 kHz.

    The time axis is ``arange(n) / sample_rate`` so consecutive samples are
    exactly ``1 / sample_rate`` seconds apart. ``torch.linspace(0, duration, n)``
    would include the end point, stretching the spacing to ``duration / (n - 1)``
    and shifting the tone slightly away from 440 Hz.

    Returns:
        Float tensor of shape ``(1, sample_rate * duration)`` -- a batch of one.
    """
    sample_rate = 16000
    duration = 1.0  # seconds
    num_samples = int(sample_rate * duration)
    t = torch.arange(num_samples, dtype=torch.float32) / sample_rate
    frequency = 440.0  # Hz
    signal = torch.sin(2 * np.pi * frequency * t)
    return signal.unsqueeze(0)  # add batch dimension -> (1, num_samples)


def test_stft_reconstruction(sample_audio):
    """CustomSTFT's output should match TorchSTFT's and reproduce the input.

    Both modules are invoked as callables on the same signal with the same
    configuration. Their outputs must have identical shapes and agree
    element-wise, and the (round-trip) output must reproduce the input.
    """
    # Identical STFT configuration for both implementations.
    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)

    custom_output = custom_stft(sample_audio)
    torch_output = torch_stft(sample_audio)

    # torch.allclose broadcasts its arguments, so without this check a shape
    # mismatch such as (1, T) vs (1, 1, T) would pass silently.
    assert custom_output.shape == torch_output.shape
    assert torch.allclose(custom_output, torch_output, rtol=1e-3, atol=1e-3)

    # A reconstruction test should also compare against the original signal.
    # Looser tolerance: errors vs. TorchSTFT and TorchSTFT vs. input add up.
    reconstructed = custom_output.reshape_as(sample_audio)
    assert torch.allclose(reconstructed, sample_audio, rtol=1e-2, atol=1e-2)


def test_magnitude_phase_consistency(sample_audio):
    """CustomSTFT.transform should agree with TorchSTFT.transform.

    Each ``transform`` returns a ``(magnitude, phase)`` pair. Magnitudes are
    compared directly. Phases are compared through the real and imaginary
    parts ``mag * cos(phase)`` and ``mag * sin(phase)``. This sidesteps the
    +/-pi wrap-around of raw angles and down-weights bins whose phase is
    numerically meaningless because their magnitude is ~0.
    """
    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)

    custom_mag, custom_phase = custom_stft.transform(sample_audio)
    torch_mag, torch_phase = torch_stft.transform(sample_audio)

    # allclose broadcasts, so pin the shapes explicitly first.
    assert custom_mag.shape == torch_mag.shape
    assert custom_phase.shape == torch_phase.shape

    # Ignore the two boundary frames at each end (presumably where the two
    # implementations' edge padding can differ).
    custom_mag_center = custom_mag[..., 2:-2]
    torch_mag_center = torch_mag[..., 2:-2]
    custom_phase_center = custom_phase[..., 2:-2]
    torch_phase_center = torch_phase[..., 2:-2]

    assert torch.allclose(custom_mag_center, torch_mag_center, rtol=1e-2, atol=1e-2)

    # Phase check via the real/imaginary parts of the spectrum.
    custom_real = custom_mag_center * torch.cos(custom_phase_center)
    custom_imag = custom_mag_center * torch.sin(custom_phase_center)
    torch_real = torch_mag_center * torch.cos(torch_phase_center)
    torch_imag = torch_mag_center * torch.sin(torch_phase_center)
    assert torch.allclose(custom_real, torch_real, rtol=1e-2, atol=1e-2)
    assert torch.allclose(custom_imag, torch_imag, rtol=1e-2, atol=1e-2)


def test_batch_processing():
    """CustomSTFT should handle a batch with each item processed independently.

    Every batch row carries a different tone, so mixing up or
    cross-contaminating batch items is detected rather than hidden behind
    identical rows. Each batched result is compared with the result of
    processing that row on its own.
    """
    batch_size = 4
    sample_rate = 16000
    duration = 0.1  # shorter duration for faster testing
    num_samples = int(sample_rate * duration)
    # Exact 1/sample_rate spacing (linspace would include the end point).
    t = torch.arange(num_samples, dtype=torch.float32) / sample_rate
    base_frequency = 440.0
    # Row k is a sine at (k + 1) * 440 Hz -> 440, 880, 1320, 1760 Hz.
    frequencies = base_frequency * torch.arange(1, batch_size + 1, dtype=torch.float32)
    signals = torch.sin(2 * np.pi * frequencies.unsqueeze(1) * t)  # (batch, time)

    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)

    # Process batch
    output = custom_stft(signals)

    # Check output shape
    assert output.shape[0] == batch_size
    assert len(output.shape) == 3  # (batch, 1, time)

    # Each batched row must match processing that row alone.
    for index, (row, batched) in enumerate(zip(signals, output)):
        single = custom_stft(row.unsqueeze(0))
        assert torch.allclose(batched.unsqueeze(0), single, rtol=1e-4, atol=1e-4), (
            f"batch item {index} differs from unbatched processing"
        )


def test_different_window_sizes():
    """CustomSTFT output should be at least as long as the input for several sizes.

    The window length equals the filter length and the hop is a quarter of it,
    for filter lengths of 512, 1024 and 2048.
    """
    # Seeded generator so the noise, and therefore any failure, is reproducible.
    generator = torch.Generator().manual_seed(0)
    signal = torch.randn(1, 16000, generator=generator)  # 16000 samples (1 s at 16 kHz)

    for filter_length in (512, 1024, 2048):
        custom_stft = CustomSTFT(
            filter_length=filter_length,
            hop_length=filter_length // 4,
            win_length=filter_length,
        )

        # Forward and backward transform
        output = custom_stft(signal)

        # Check that output length is reasonable; name the failing size.
        assert output.shape[-1] >= signal.shape[-1], (
            f"filter_length={filter_length}: output length {output.shape[-1]} "
            f"< input length {signal.shape[-1]}"
        )