database_test.go 11.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
//go:build windows || darwin

package store

import (
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	_ "github.com/mattn/go-sqlite3"
)

// TestSchemaMigrations exercises the schema migration path of the store.
//
// Subtests:
//   - a database seeded through loadV2Schema and then migrated must expose the
//     same tables, columns, indexes and schema version as a database created
//     by newDatabase;
//   - calling migrate twice in a row must succeed and leave the version at
//     currentSchemaVersion;
//   - a database created by newDatabase must report currentSchemaVersion.
func TestSchemaMigrations(t *testing.T) {
	t.Run("schema comparison after migration", func(t *testing.T) {
		tmpDir := t.TempDir()
		migratedDBPath := filepath.Join(tmpDir, "migrated.db")
		migratedDB := loadV2Schema(t, migratedDBPath)
		defer migratedDB.Close()

		if err := migratedDB.migrate(); err != nil {
			t.Fatalf("migration failed: %v", err)
		}

		// Create fresh database with current schema
		freshDBPath := filepath.Join(tmpDir, "fresh.db")
		freshDB, err := newDatabase(freshDBPath)
		if err != nil {
			t.Fatalf("failed to create fresh database: %v", err)
		}
		defer freshDB.Close()

		// Extract tables and indexes from both databases; comparing the raw
		// schema text directly won't work due to ordering differences.
		migratedSchema := schemaMap(migratedDB)
		freshSchema := schemaMap(freshDB)

		// cmp.Diff returns the empty string exactly when the values are equal,
		// so a single call both detects and reports the difference.
		if diff := cmp.Diff(freshSchema, migratedSchema); diff != "" {
			t.Errorf("Schema difference found:\n%s", diff)
		}

		// Verify both databases have the same final schema version. The
		// errors must be checked: if both reads failed, the zero values would
		// otherwise compare equal and the assertion would pass vacuously.
		migratedVersion, err := migratedDB.getSchemaVersion()
		if err != nil {
			t.Fatalf("failed to get schema version of migrated database: %v", err)
		}
		freshVersion, err := freshDB.getSchemaVersion()
		if err != nil {
			t.Fatalf("failed to get schema version of fresh database: %v", err)
		}
		if migratedVersion != freshVersion {
			t.Errorf("schema version mismatch: migrated=%d, fresh=%d", migratedVersion, freshVersion)
		}
	})

	t.Run("idempotent migrations", func(t *testing.T) {
		tmpDir := t.TempDir()
		dbPath := filepath.Join(tmpDir, "test.db")
		db := loadV2Schema(t, dbPath)
		defer db.Close()

		// Run migration twice; the second run must be a harmless no-op.
		if err := db.migrate(); err != nil {
			t.Fatalf("first migration failed: %v", err)
		}

		if err := db.migrate(); err != nil {
			t.Fatalf("second migration failed: %v", err)
		}

		// Verify schema version is still correct
		version, err := db.getSchemaVersion()
		if err != nil {
			t.Fatalf("failed to get schema version: %v", err)
		}
		if version != currentSchemaVersion {
			t.Errorf("expected schema version %d after double migration, got %d", currentSchemaVersion, version)
		}
	})

	t.Run("init database has correct schema version", func(t *testing.T) {
		tmpDir := t.TempDir()
		dbPath := filepath.Join(tmpDir, "test.db")
		db, err := newDatabase(dbPath)
		if err != nil {
			t.Fatalf("failed to create database: %v", err)
		}
		defer db.Close()

		// Get the schema version from the newly initialized database
		version, err := db.getSchemaVersion()
		if err != nil {
			t.Fatalf("failed to get schema version: %v", err)
		}

		// Verify it matches the currentSchemaVersion constant
		if version != currentSchemaVersion {
			t.Errorf("expected schema version %d in initialized database, got %d", currentSchemaVersion, version)
		}
	})
}

// TestChatDeletionWithCascade covers how message rows follow the lifetime of
// their parent chat row.
//
// Subtests:
//   - saving a chat with two messages and then calling deleteChat leaves
//     neither the chat row nor any of its message rows behind;
//   - PRAGMA foreign_keys reports 1 on a database opened by newDatabase;
//   - cleanupOrphanedData removes a message whose chat row was deleted while
//     foreign keys were switched off.
func TestChatDeletionWithCascade(t *testing.T) {
	t.Run("chat deletion cascades to related messages", func(t *testing.T) {
		const (
			chatID      = "test-chat-cascade-123"
			existsQuery = "SELECT EXISTS(SELECT 1 FROM chats WHERE id = ?)"
		)

		db, err := newDatabase(filepath.Join(t.TempDir(), "test.db"))
		if err != nil {
			t.Fatalf("failed to create database: %v", err)
		}
		defer db.Close()

		// Fixture: one chat carrying a user message and an assistant reply.
		chat := Chat{
			ID:        chatID,
			Title:     "Test Chat for Cascade Delete",
			CreatedAt: time.Now(),
		}
		chat.Messages = append(chat.Messages,
			Message{
				Role:      "user",
				Content:   "Hello, this is a test message",
				CreatedAt: time.Now(),
				UpdatedAt: time.Now(),
			},
			Message{
				Role:      "assistant",
				Content:   "Hi there! This is a response.",
				CreatedAt: time.Now(),
				UpdatedAt: time.Now(),
			},
		)

		// Persist the chat together with its messages.
		if err := db.saveChat(chat); err != nil {
			t.Fatalf("failed to save test chat: %v", err)
		}

		// Pre-delete state: exactly the fixture rows are in both tables.
		chatsBefore, messagesBefore := countRows(t, db, "chats"), countRows(t, db, "messages")
		if chatsBefore != 1 {
			t.Errorf("expected 1 chat, got %d", chatsBefore)
		}
		if messagesBefore != 2 {
			t.Errorf("expected 2 messages, got %d", messagesBefore)
		}

		// The chat row itself must be findable by its ID.
		var present bool
		if err := db.conn.QueryRow(existsQuery, chatID).Scan(&present); err != nil {
			t.Fatalf("failed to check chat existence: %v", err)
		}
		if !present {
			t.Error("test chat should exist before deletion")
		}

		// Both messages must be attached to this particular chat.
		if n := countRowsWithCondition(t, db, "messages", "chat_id = ?", chatID); n != 2 {
			t.Errorf("expected 2 messages for test chat, got %d", n)
		}

		// Act: remove the chat.
		if err := db.deleteChat(chatID); err != nil {
			t.Fatalf("failed to delete chat: %v", err)
		}

		// Post-delete state: the chats table is empty ...
		if n := countRows(t, db, "chats"); n != 0 {
			t.Errorf("expected 0 chats after deletion, got %d", n)
		}

		// ... and the messages went away with it (CASCADE).
		if n := countRows(t, db, "messages"); n != 0 {
			t.Errorf("expected 0 messages after CASCADE deletion, got %d", n)
		}

		// The chat can no longer be looked up by ID.
		if err := db.conn.QueryRow(existsQuery, chatID).Scan(&present); err != nil {
			t.Fatalf("failed to check chat existence after deletion: %v", err)
		}
		if present {
			t.Error("test chat should not exist after deletion")
		}

		// No message row still points at the deleted chat.
		if n := countRowsWithCondition(t, db, "messages", "chat_id = ?", chatID); n != 0 {
			t.Errorf("expected 0 orphaned messages, got %d", n)
		}
	})

	t.Run("foreign keys are enabled", func(t *testing.T) {
		db, err := newDatabase(filepath.Join(t.TempDir(), "test.db"))
		if err != nil {
			t.Fatalf("failed to create database: %v", err)
		}
		defer db.Close()

		// Ask SQLite directly whether foreign key enforcement is on.
		var fkState int
		if err := db.conn.QueryRow("PRAGMA foreign_keys").Scan(&fkState); err != nil {
			t.Fatalf("failed to check foreign keys: %v", err)
		}
		if fkState != 1 {
			t.Errorf("expected foreign keys to be enabled (1), got %d", fkState)
		}
	})

	// Only the v8 migration really needs this test; it stays here for now as a
	// guard against introducing new orphaned data.
	t.Run("cleanup orphaned data", func(t *testing.T) {
		const (
			chatID          = "orphaned-test-chat"
			messageID int64 = 999
		)

		db, err := newDatabase(filepath.Join(t.TempDir(), "test.db"))
		if err != nil {
			t.Fatalf("failed to create database: %v", err)
		}
		defer db.Close()

		// Turn foreign keys off to reproduce the bug from ollama/ollama#11785.
		if _, err := db.conn.Exec("PRAGMA foreign_keys = OFF"); err != nil {
			t.Fatalf("failed to disable foreign keys: %v", err)
		}

		// Insert a chat and one message that belongs to it.
		if _, err := db.conn.Exec("INSERT INTO chats (id, title) VALUES (?, ?)", chatID, "Orphaned Test Chat"); err != nil {
			t.Fatalf("failed to insert test chat: %v", err)
		}

		if _, err := db.conn.Exec("INSERT INTO messages (id, chat_id, role, content) VALUES (?, ?, ?, ?)",
			messageID, chatID, "user", "test message"); err != nil {
			t.Fatalf("failed to insert test message: %v", err)
		}

		// Drop only the chat row so the message is left dangling
		// (the situation described in ollama/ollama#11785).
		if _, err := db.conn.Exec("DELETE FROM chats WHERE id = ?", chatID); err != nil {
			t.Fatalf("failed to delete chat: %v", err)
		}

		// Sanity check: the orphan really exists before cleanup runs.
		if n := countRowsWithCondition(t, db, "messages", "chat_id = ?", chatID); n != 1 {
			t.Errorf("expected 1 orphaned message, got %d", n)
		}

		// Act: run the cleanup routine.
		if err := db.cleanupOrphanedData(); err != nil {
			t.Fatalf("failed to cleanup orphaned data: %v", err)
		}

		// The orphan must have been removed.
		if n := countRowsWithCondition(t, db, "messages", "chat_id = ?", chatID); n != 0 {
			t.Errorf("expected 0 orphaned messages after cleanup, got %d", n)
		}
	})
}

// countRows reports how many rows the named table currently holds. If the
// COUNT(*) query cannot be executed or scanned, the calling test is aborted
// with t.Fatalf. The table name is interpolated into the SQL text verbatim.
func countRows(t *testing.T, db *database, table string) int {
	t.Helper()

	query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table)

	var n int
	if err := db.conn.QueryRow(query).Scan(&n); err != nil {
		t.Fatalf("failed to count rows in %s: %v", table, err)
	}
	return n
}

// countRowsWithCondition reports how many rows of table satisfy condition.
// Both table and condition are spliced into the SQL text verbatim; args are
// bound to the placeholders that condition contains. A query or scan failure
// aborts the calling test with t.Fatalf.
func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...interface{}) int {
	t.Helper()

	row := db.conn.QueryRow("SELECT COUNT(*) FROM "+table+" WHERE "+condition, args...)

	var n int
	if err := row.Scan(&n); err != nil {
		t.Fatalf("failed to count rows with condition: %v", err)
	}
	return n
}

// Test helpers for schema migration testing

// schemaMap snapshots the database schema as a two-entry map: "tables" holds
// the result of columnMap and "indexes" holds the result of indexMap. Both
// are keyed by name, so comparing two snapshots does not depend on the order
// in which objects were created.
func schemaMap(db *database) map[string]interface{} {
	return map[string]interface{}{
		"tables":  columnMap(db),
		"indexes": indexMap(db),
	}
}

// columnMap returns, for every table listed in sqlite_master whose name does
// not start with "sqlite_", a sorted slice of normalized column descriptions
// of the form "name TYPE[ NOT NULL][ DEFAULT expr][ PRIMARY KEY]". Sorting
// makes the result independent of column order.
//
// The function has no *testing.T to report through, so any query, scan or
// iteration error panics with a descriptive message. Silently continuing
// would produce a truncated schema that could compare equal to another
// truncated schema and hide a real difference.
func columnMap(db *database) map[string][]string {
	const tableQuery = `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`

	// Collect all table names up front so the outer cursor is exhausted
	// before the per-table PRAGMA queries are issued on the same handle.
	rows, err := db.conn.Query(tableQuery)
	if err != nil {
		panic(fmt.Errorf("columnMap: listing tables: %w", err))
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var table string
		if err := rows.Scan(&table); err != nil {
			panic(fmt.Errorf("columnMap: scanning table name: %w", err))
		}
		tables = append(tables, table)
	}
	if err := rows.Err(); err != nil {
		panic(fmt.Errorf("columnMap: iterating tables: %w", err))
	}

	// describe returns the sorted column descriptions of one table. It is a
	// closure so that the deferred Close runs once per table rather than once
	// at the end of columnMap.
	describe := func(table string) []string {
		colRows, err := db.conn.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
		if err != nil {
			panic(fmt.Errorf("columnMap: reading columns of %s: %w", table, err))
		}
		defer colRows.Close()

		var columns []string
		for colRows.Next() {
			var cid int
			var name, dataType sql.NullString
			var notNull, primaryKey int
			var defaultValue sql.NullString

			if err := colRows.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &primaryKey); err != nil {
				panic(fmt.Errorf("columnMap: scanning column of %s: %w", table, err))
			}

			// Create a normalized column description
			colDesc := fmt.Sprintf("%s %s", name.String, dataType.String)
			if notNull == 1 {
				colDesc += " NOT NULL"
			}
			// Skip DEFAULT for schema_version as it doesn't get updated during migrations
			if defaultValue.Valid && defaultValue.String != "" && name.String != "schema_version" {
				colDesc += " DEFAULT " + defaultValue.String
			}
			// NOTE(review): per the SQLite docs, pk is the 1-based position of
			// the column inside the primary key, so only the first column of a
			// composite key is tagged here - confirm this is intended.
			if primaryKey == 1 {
				colDesc += " PRIMARY KEY"
			}

			columns = append(columns, colDesc)
		}
		if err := colRows.Err(); err != nil {
			panic(fmt.Errorf("columnMap: iterating columns of %s: %w", table, err))
		}

		// Sort columns to ignore order differences
		sort.Strings(columns)
		return columns
	}

	result := make(map[string][]string, len(tables))
	for _, table := range tables {
		result[table] = describe(table)
	}
	return result
}

// indexMap returns a map from index name to its whitespace-normalized CREATE
// statement, as stored in sqlite_master. Indexes whose name starts with
// "sqlite_" or whose sql column is NULL are left out by the query.
//
// The function has no *testing.T to report through, so any query, scan or
// iteration error panics with a descriptive message instead of silently
// returning a partial map.
func indexMap(db *database) map[string]string {
	const indexQuery = `SELECT name, sql FROM sqlite_master WHERE type='index' AND name NOT LIKE 'sqlite_%' AND sql IS NOT NULL ORDER BY name`

	rows, err := db.conn.Query(indexQuery)
	if err != nil {
		panic(fmt.Errorf("indexMap: listing indexes: %w", err))
	}
	defer rows.Close()

	result := make(map[string]string)
	for rows.Next() {
		// "definition" rather than "sql" so the database/sql package is not
		// shadowed inside the loop.
		var name, definition string
		if err := rows.Scan(&name, &definition); err != nil {
			panic(fmt.Errorf("indexMap: scanning index row: %w", err))
		}

		// Normalize the SQL by collapsing every run of whitespace to one space.
		result[name] = strings.Join(strings.Fields(definition), " ")
	}
	if err := rows.Err(); err != nil {
		panic(fmt.Errorf("indexMap: iterating indexes: %w", err))
	}

	return result
}

// loadV2Schema creates a SQLite database at dbPath and loads the version 2
// schema from testdata/schema.sql into it. The returned *database wraps the
// raw connection directly (it is not produced by newDatabase), so callers
// decide when to run migrate. Any failure aborts the test with t.Fatalf.
// The caller is responsible for closing the returned database.
func loadV2Schema(t *testing.T, dbPath string) *database {
	t.Helper()

	// Read the v2 schema file
	schemaFile := filepath.Join("testdata", "schema.sql")
	schemaSQL, err := os.ReadFile(schemaFile)
	if err != nil {
		t.Fatalf("failed to read schema file: %v", err)
	}

	// Open database connection; the connection options are passed as DSN
	// query parameters.
	conn, err := sql.Open("sqlite3", dbPath+"?_foreign_keys=on&_journal_mode=WAL&_busy_timeout=5000&_txlock=immediate")
	if err != nil {
		t.Fatalf("failed to open database: %v", err)
	}

	// Execute the v2 schema; close the handle on failure so it is not leaked
	// when the test aborts.
	_, err = conn.Exec(string(schemaSQL))
	if err != nil {
		conn.Close()
		t.Fatalf("failed to execute v2 schema: %v", err)
	}

	return &database{conn: conn}
}