llama_test.go 2.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package llm

import (
	"bytes"
	"testing"
)

// TestCheckStopConditions is a table-driven test for handleStopSequences.
// Each case writes b into a bytes.Buffer, calls handleStopSequences with the
// given stop sequences, and then checks three things:
//   - the returned stop flag;
//   - the returned endsWithStopPrefix flag;
//   - the buffer contents afterwards (wantB).
//
// wantB is always compared, including when it is the empty string, so a match
// that consumes the whole buffer ("exact") is verified too.
//
// NOTE(review): the non-stop cases expect the buffer to be left untouched.
// Confirm this against handleStopSequences if its behaviour changes.
func TestCheckStopConditions(t *testing.T) {
	tests := map[string]struct {
		b                      string
		stop                   []string
		wantB                  string
		wantStop               bool
		wantEndsWithStopPrefix bool
	}{
		"not present": {
			b:                      "abc",
			stop:                   []string{"x"},
			wantB:                  "abc",
			wantStop:               false,
			wantEndsWithStopPrefix: false,
		},
		// The match starts at index 0, so truncating at the match leaves nothing.
		"exact": {
			b:                      "abc",
			stop:                   []string{"abc"},
			wantB:                  "",
			wantStop:               true,
			wantEndsWithStopPrefix: false,
		},
		"substring": {
			b:                      "abc",
			stop:                   []string{"b"},
			wantB:                  "a",
			wantStop:               true,
			wantEndsWithStopPrefix: false,
		},
		"prefix 1": {
			b:                      "abc",
			stop:                   []string{"abcd"},
			wantB:                  "abc",
			wantStop:               false,
			wantEndsWithStopPrefix: true,
		},
		"prefix 2": {
			b:                      "abc",
			stop:                   []string{"bcd"},
			wantB:                  "abc",
			wantStop:               false,
			wantEndsWithStopPrefix: true,
		},
		"prefix 3": {
			b:                      "abc",
			stop:                   []string{"cd"},
			wantB:                  "abc",
			wantStop:               false,
			wantEndsWithStopPrefix: true,
		},
		"no prefix": {
			b:                      "abc",
			stop:                   []string{"bx"},
			wantB:                  "abc",
			wantStop:               false,
			wantEndsWithStopPrefix: false,
		},
	}
	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			var b bytes.Buffer
			b.WriteString(test.b)
			stop, endsWithStopPrefix := handleStopSequences(&b, test.stop)
			// Always compare: an empty wantB is a real expectation, not "skip".
			if gotB := b.String(); gotB != test.wantB {
				t.Errorf("got b %q, want %q", gotB, test.wantB)
			}
			if stop != test.wantStop {
				t.Errorf("got stop %v, want %v", stop, test.wantStop)
			}
			if endsWithStopPrefix != test.wantEndsWithStopPrefix {
				t.Errorf("got endsWithStopPrefix %v, want %v", endsWithStopPrefix, test.wantEndsWithStopPrefix)
			}
		})
	}
}