scheduler.go
//go:build mlx

package qwen_image

import (
	"math"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
	NumTrainTimesteps int32   `json:"num_train_timesteps"` // 1000
	BaseShift         float32 `json:"base_shift"`          // 0.5
	MaxShift          float32 `json:"max_shift"`           // 0.9
	BaseImageSeqLen   int32   `json:"base_image_seq_len"`  // 256
	MaxImageSeqLen    int32   `json:"max_image_seq_len"`   // 8192
	ShiftTerminal     float32 `json:"shift_terminal"`      // 0.02
	UseDynamicShift   bool    `json:"use_dynamic_shifting"` // true
}

// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
	return &SchedulerConfig{
		NumTrainTimesteps: 1000,
		BaseShift:         0.5,
		MaxShift:          0.9, // Matches scheduler_config.json
		BaseImageSeqLen:   256,
		MaxImageSeqLen:    8192,
		ShiftTerminal:     0.02,
		UseDynamicShift:   true,
	}
}

// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
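//
// A typical call sequence looks roughly like this (illustrative sketch, not the
// actual pipeline wiring; seed and predictVelocity are hypothetical stand-ins):
//
//	sched := NewFlowMatchScheduler(DefaultSchedulerConfig())
//	sched.SetTimesteps(30, 64*64) // 30 steps; 1024x1024 image -> 4096 packed tokens
//	latents := sched.InitNoisePacked(1, 64*64, 64, seed)
//	for i := 0; i < sched.NumSteps; i++ {
//		t := sched.GetTimestep(i)
//		v := predictVelocity(latents, t) // hypothetical transformer call
//		latents = sched.Step(v, latents, i)
//	}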
type FlowMatchScheduler struct {
	Config    *SchedulerConfig
	Timesteps []float32
	Sigmas    []float32
	NumSteps  int
}

// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
	return &FlowMatchScheduler{
		Config: cfg,
	}
}

// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
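//
// Illustrative values with the default config: a 1024x1024 image gives a packed
// latent sequence length of (1024/8/2)*(1024/8/2) = 4096, so
//
//	CalculateShift(4096, 256, 8192, 0.5, 0.9) ≈ 0.69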
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
	m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
	b := baseShift - m*float32(baseSeqLen)
	mu := float32(imageSeqLen)*m + b
	return mu
}

// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
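//
// With the default config and 30 steps (illustrative): the first sigma stays at
// 1.0 (t = 1 is a fixed point of both timeShift and the stretch), the 30th is
// stretched to exactly shift_terminal = 0.02, and a terminal 0 is appended, so
// Sigmas and Timesteps each hold 31 values.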
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
	s.NumSteps = numSteps

	// Calculate mu for dynamic shifting
	var mu float32
	if s.Config.UseDynamicShift {
		mu = CalculateShift(
			imageSeqLen,
			s.Config.BaseImageSeqLen,
			s.Config.MaxImageSeqLen,
			s.Config.BaseShift,
			s.Config.MaxShift,
		)
	}

	// Step 1: Create sigmas from 1.0 to 1/num_steps
	// Python (pipeline_qwenimage.py:639):
	//   sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
	// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
	sigmas := make([]float32, numSteps)
	sigmaMax := float32(1.0)
	sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
	if numSteps == 1 {
		sigmas[0] = sigmaMax
	} else {
		for i := 0; i < numSteps; i++ {
			sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
		}
	}

	// Step 2: Apply time shift if using dynamic shifting
	if s.Config.UseDynamicShift && mu != 0 {
		for i := range sigmas {
			sigmas[i] = s.timeShift(mu, sigmas[i])
		}
	}

	// Step 3: Apply stretch_shift_to_terminal
	if s.Config.ShiftTerminal > 0 {
		sigmas = s.stretchShiftToTerminal(sigmas)
	}

	// Step 4: Append terminal sigma (0) and store
	// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
	// before passing to transformer. We skip both steps and just use sigmas directly.
	s.Sigmas = make([]float32, numSteps+1)
	s.Timesteps = make([]float32, numSteps+1)
	for i := 0; i < numSteps; i++ {
		s.Sigmas[i] = sigmas[i]
		s.Timesteps[i] = sigmas[i]
	}
	s.Sigmas[numSteps] = 0.0
	s.Timesteps[numSteps] = 0.0
}

// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
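//
// Illustratively, with 30 steps and the default config the last shifted sigma is
// about 0.065, so scaleFactor ≈ (1 - 0.065) / (1 - 0.02) ≈ 0.955 and each sigma t
// becomes 1 - (1-t)/0.955, which leaves 1.0 fixed and maps the last value to 0.02.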
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
	if len(sigmas) == 0 {
		return sigmas
	}

	// one_minus_z = 1 - t
	// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
	// stretched_t = 1 - (one_minus_z / scale_factor)
	lastSigma := sigmas[len(sigmas)-1]
	scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)

	// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
	// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
	if scaleFactor < 1e-6 {
		return sigmas
	}

	result := make([]float32, len(sigmas))
	for i, t := range sigmas {
		oneMinusZ := 1.0 - t
		result[i] = 1.0 - (oneMinusZ / scaleFactor)
	}
	return result
}

// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
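// Illustratively, with mu ≈ 0.69 (so exp(mu) ≈ 2), a mid-schedule value t = 0.5
// maps to 2 / (2 + 1) ≈ 0.67, so the shifted schedule spends more of its steps
// at high noise levels.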
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
	if t <= 0 {
		return 0
	}
	expMu := float32(math.Exp(float64(mu)))
	return expMu / (expMu + (1.0/t - 1.0))
}

// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
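//
// Since the sigma schedule decreases toward 0, dt is negative: each call
// integrates the flow ODE dx/dsigma = v one Euler step backward, from pure
// noise (sigma = 1) toward the clean sample (sigma = 0).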
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
	// Get current and next sigma
	sigma := s.Sigmas[timestepIdx]
	sigmaNext := s.Sigmas[timestepIdx+1]

	// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
	dt := sigmaNext - sigma

	// Upcast to float32 to avoid precision issues (matches Python diffusers)
	sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
	modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)

	scaledOutput := mlx.MulScalar(modelOutputF32, dt)
	result := mlx.Add(sampleF32, scaledOutput)

	// Cast back down to bfloat16, the pipeline's working dtype
	return mlx.ToBFloat16(result)
}

// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
	if idx < len(s.Timesteps) {
		return s.Timesteps[idx]
	}
	return 0.0
}

// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
	return mlx.RandomNormal(shape, uint64(seed))
}

// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
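// For a 1024x1024 image this is typically InitNoisePacked(1, 4096, 64, seed),
// i.e. a [1, 4096, 64] array of standard normal noise (illustrative values).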
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
	shape := []int32{batchSize, seqLen, channels}
	return mlx.RandomNormal(shape, uint64(seed))
}

// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
	latentH := height / 8
	latentW := width / 8
	return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}

// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = (latent_H/patch_size) * (latent_W/patch_size)
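//
// Illustrative shapes for a 1024x1024 image with patch size 2:
//
//	GetLatentShape(1, 1024, 1024)           -> [1, 16, 1, 128, 128]
//	GetPatchedLatentShape(1, 1024, 1024, 2) -> [1, 4096, 64]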
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
	latentH := height / 8
	latentW := width / 8
	pH := latentH / patchSize
	pW := latentW / patchSize
	inChannels := 16 * patchSize * patchSize // latent channels (16) * patch_size^2
	return []int32{batchSize, pH * pW, inChannels}
}