scheduler_test.go 4.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
//go:build mlx

package qwen_image

import (
	"math"
	"testing"
)

// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
//	python3 -c "
//	from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
//	import numpy as np
//	s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
//	    base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
//	mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
//	sigmas = np.linspace(1.0, 1.0/30, 30)
//	s.set_timesteps(sigmas=sigmas, mu=mu)
//	print(s.sigmas.numpy())"
//
// For 30 steps the schedule is expected to hold 31 entries: 30 shifted sigmas
// followed by a terminal 0. Only the first three values and the last three
// values before the terminal are pinned, using an absolute tolerance of 1e-4.
func TestSchedulerSetTimesteps(t *testing.T) {
	const numSteps = 30

	cfg := DefaultSchedulerConfig()
	scheduler := NewFlowMatchScheduler(cfg)
	scheduler.SetTimesteps(numSteps, 4096)
	sigmas := scheduler.Sigmas

	// Check length first: every index below assumes numSteps+1 entries, and a
	// short slice would otherwise panic before any length error is reported.
	if len(sigmas) != numSteps+1 {
		t.Fatalf("sigmas length: got %d, want %d", len(sigmas), numSteps+1)
	}

	// Golden values from Python diffusers (first 3, last 3 before terminal).
	wantFirst := []float32{1.000000, 0.982251, 0.963889}
	wantLast := []float32{0.142924, 0.083384, 0.020000}

	for i, want := range wantFirst {
		if got := sigmas[i]; abs32(got-want) > 1e-4 {
			t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
		}
	}

	// wantLast fills the slots immediately before the terminal entry
	// (indices 27, 28, 29 for 30 steps).
	lastStart := numSteps - len(wantLast)
	for i, want := range wantLast {
		idx := lastStart + i
		if got := sigmas[idx]; abs32(got-want) > 1e-4 {
			t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
		}
	}

	// The terminal entry must be exactly 0.
	if sigmas[numSteps] != 0.0 {
		t.Errorf("terminal sigma: got %v, want 0", sigmas[numSteps])
	}
}

// TestSchedulerProperties tests mathematical invariants of a 30-step schedule
// built for image sequence length 4096:
//   - the slice holds NumSteps+1 entries,
//   - sigmas never increase,
//   - the first sigma lies in [0.9, 1.01],
//   - the final (terminal) sigma is exactly 0,
//   - the last non-terminal sigma is 0.02 (the shift_terminal used by the
//     diffusers reference; assumed to be DefaultSchedulerConfig's value).
func TestSchedulerProperties(t *testing.T) {
	cfg := DefaultSchedulerConfig()
	scheduler := NewFlowMatchScheduler(cfg)
	scheduler.SetTimesteps(30, 4096)
	sigmas := scheduler.Sigmas

	// Property: length = steps + 1
	if len(sigmas) != scheduler.NumSteps+1 {
		t.Errorf("sigmas length should be steps+1: got %d, want %d",
			len(sigmas), scheduler.NumSteps+1)
	}

	// The checks below read sigmas[0] and sigmas[len-2]; stop cleanly instead
	// of panicking on an index out of range.
	if len(sigmas) < 2 {
		t.Fatalf("need at least 2 sigmas to check invariants, got %d", len(sigmas))
	}

	// Property: sigmas monotonically decreasing
	for i := 1; i < len(sigmas); i++ {
		if sigmas[i] > sigmas[i-1] {
			t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
				i, sigmas[i], sigmas[i-1])
		}
	}

	// Property: first sigma should be ~1.0 (with time shift)
	if sigmas[0] < 0.9 || sigmas[0] > 1.01 {
		t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", sigmas[0])
	}

	// Property: terminal sigma should be exactly 0
	if last := sigmas[len(sigmas)-1]; last != 0.0 {
		t.Errorf("terminal sigma should be 0, got %v", last)
	}

	// Property: last non-terminal sigma should be shift_terminal (0.02)
	lastNonTerminal := sigmas[len(sigmas)-2]
	if abs32(lastNonTerminal-0.02) > 1e-5 {
		t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
	}
}

// TestCalculateShift checks CalculateShift against hand-computed golden values.
// The reference is the straight line through (base_seq_len, base_shift) and
// (max_seq_len, max_shift):
//
//	m  = (max_shift - base_shift) / (max_seq_len - base_seq_len)
//	b  = base_shift - m*base_seq_len
//	mu = img_seq_len*m + b
//
// With base_seq_len=256, max_seq_len=8192, base_shift=0.5 and max_shift=0.9,
// the endpoints land exactly on 0.5 and 0.9, and 4096 gives ≈0.69355.
// Results are accepted within an absolute tolerance of 0.001.
func TestCalculateShift(t *testing.T) {
	type shiftCase struct {
		seqLen int32
		mu     float32
	}
	golden := []shiftCase{
		{seqLen: 256, mu: 0.5},     // base_seq_len -> base_shift
		{seqLen: 8192, mu: 0.9},    // max_seq_len -> max_shift
		{seqLen: 4096, mu: 0.6935}, // interior point, rounded
	}

	for i := range golden {
		seqLen, want := golden[i].seqLen, golden[i].mu
		got := CalculateShift(seqLen, 256, 8192, 0.5, 0.9)
		if diff := abs32(got - want); diff > 0.001 {
			t.Errorf("CalculateShift(%d): got %v, want %v", seqLen, got, want)
		}
	}
}

// TestSchedulerStep checks the step sizes dt = sigma[i+1] - sigma[i] that an
// Euler update over a 30-step schedule (image sequence length 4096) would use.
// Every dt must be strictly negative, i.e. each step moves sigma toward the
// terminal 0.
//
// NOTE(review): despite its name, this test does not call the scheduler's
// Step method. TODO: compare Step's output against the expected Euler update
// once the method's signature is pinned down here.
func TestSchedulerStep(t *testing.T) {
	cfg := DefaultSchedulerConfig()
	scheduler := NewFlowMatchScheduler(cfg)
	scheduler.SetTimesteps(30, 4096)
	sigmas := scheduler.Sigmas

	// At least two sigmas are needed to form a single step.
	if len(sigmas) < 2 {
		t.Fatalf("need at least 2 sigmas to form a step, got %d", len(sigmas))
	}

	for i := 0; i+1 < len(sigmas); i++ {
		dt := sigmas[i+1] - sigmas[i]
		// dt should be negative (sigmas decrease)
		if dt >= 0 {
			t.Errorf("step %d: expected negative dt, got %v (sigma%d=%v, sigma%d=%v)",
				i, dt, i, sigmas[i], i+1, sigmas[i+1])
		}
	}
}

// abs32 returns |x| for a float32 by clearing the IEEE-754 sign bit, matching
// math.Abs semantics (-0 becomes +0, -Inf becomes +Inf, NaN stays NaN)
// without converting through float64.
func abs32(x float32) float32 {
	const signBit = 1 << 31
	return math.Float32frombits(math.Float32bits(x) &^ signBit)
}