Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
41be2809
Commit
41be2809
authored
Jul 10, 2024
by
Michael Yang
Browse files
add system prompt to first legacy template
parent
4e262eb2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
140 additions
and
28 deletions
+140
-28
server/prompt_test.go
server/prompt_test.go
+1
-1
server/routes_create_test.go
server/routes_create_test.go
+2
-2
template/template.go
template/template.go
+90
-11
template/template_test.go
template/template_test.go
+47
-14
No files found.
server/prompt_test.go
View file @
41be2809
...
...
@@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) {
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
expect
:
expect
{
prompt
:
"You're a test, Harry! I-I'm a what?
You are the Test Who Lived.
A test. And a thumping good one at that, I'd wager. "
,
prompt
:
"You
are the Test Who Lived. You
're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. "
,
},
},
}
...
...
server/routes_create_test.go
View file @
41be2809
...
...
@@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
checkFileExists
(
t
,
filepath
.
Join
(
p
,
"blobs"
,
"*"
),
[]
string
{
filepath
.
Join
(
p
,
"blobs"
,
"sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b
7"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce5
7c
d
ac
893db8139
"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
68b0323b2f21572bc09ba07554b16b379a5713ee48ef8c25a7661a1f71cfce7
7"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
eb72fb7c550ee1f1dec4039bd65382acecf5f7536a30fb
7c
c
ac
e39a8d0cb590b
"
),
})
})
...
...
template/template.go
View file @
41be2809
...
...
@@ -143,11 +143,14 @@ func (t *Template) Vars() []string {
type
Values
struct
{
Messages
[]
api
.
Message
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy
bool
}
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
system
,
collated
:=
collate
(
v
.
Messages
)
if
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
if
!
v
.
forceLegacy
&&
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
"System"
:
system
,
"Messages"
:
collated
,
...
...
@@ -157,15 +160,19 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var
b
bytes
.
Buffer
var
prompt
,
response
string
for
i
,
m
:=
range
collated
{
if
m
.
Role
==
"user"
{
switch
m
.
Role
{
case
"user"
:
prompt
=
m
.
Content
}
else
{
if
i
!=
0
{
system
=
""
}
case
"assistant"
:
response
=
m
.
Content
}
if
i
!=
len
(
collated
)
-
1
&&
prompt
!=
""
&&
response
!=
""
{
if
err
:=
t
.
Template
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
""
,
"System"
:
system
,
"Prompt"
:
prompt
,
"Response"
:
response
,
});
err
!=
nil
{
...
...
@@ -178,18 +185,21 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}
var
cut
bool
tree
:=
t
.
Template
.
Copy
()
// for the last message, cut everything after "{{ .Response }}"
tree
.
Root
.
Nodes
=
slices
.
DeleteFunc
(
tree
.
Root
.
Nodes
,
func
(
n
parse
.
Node
)
bool
{
if
slices
.
Contains
(
parseNode
(
n
),
"Response"
)
{
cut
=
true
nodes
:=
deleteNode
(
t
.
Template
.
Root
.
Copy
(),
func
(
n
parse
.
Node
)
bool
{
switch
t
:=
n
.
(
type
)
{
case
*
parse
.
ActionNode
:
case
*
parse
.
FieldNode
:
if
slices
.
Contains
(
t
.
Ident
,
"Response"
)
{
cut
=
true
}
}
return
cut
})
if
err
:=
template
.
Must
(
template
.
New
(
""
)
.
AddParseTree
(
""
,
tree
))
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
system
,
tree
:=
parse
.
Tree
{
Root
:
nodes
.
(
*
parse
.
ListNode
)}
if
err
:=
template
.
Must
(
template
.
New
(
""
)
.
AddParseTree
(
""
,
&
tree
))
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
""
,
"Prompt"
:
prompt
,
});
err
!=
nil
{
return
err
...
...
@@ -286,3 +296,72 @@ func parseNode(n parse.Node) []string {
return
nil
}
// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
func
deleteNode
(
n
parse
.
Node
,
fn
func
(
parse
.
Node
)
bool
)
parse
.
Node
{
var
walk
func
(
n
parse
.
Node
)
parse
.
Node
walk
=
func
(
n
parse
.
Node
)
parse
.
Node
{
if
fn
(
n
)
{
return
nil
}
switch
t
:=
n
.
(
type
)
{
case
*
parse
.
ListNode
:
var
nodes
[]
parse
.
Node
for
_
,
c
:=
range
t
.
Nodes
{
if
n
:=
walk
(
c
);
n
!=
nil
{
nodes
=
append
(
nodes
,
n
)
}
}
t
.
Nodes
=
nodes
return
t
case
*
parse
.
IfNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
WithNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
RangeNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
BranchNode
:
t
.
List
=
walk
(
t
.
List
)
.
(
*
parse
.
ListNode
)
if
t
.
ElseList
!=
nil
{
t
.
ElseList
=
walk
(
t
.
ElseList
)
.
(
*
parse
.
ListNode
)
}
case
*
parse
.
ActionNode
:
n
:=
walk
(
t
.
Pipe
)
if
n
==
nil
{
return
nil
}
t
.
Pipe
=
n
.
(
*
parse
.
PipeNode
)
case
*
parse
.
PipeNode
:
var
commands
[]
*
parse
.
CommandNode
for
_
,
c
:=
range
t
.
Cmds
{
var
args
[]
parse
.
Node
for
_
,
a
:=
range
c
.
Args
{
if
n
:=
walk
(
a
);
n
!=
nil
{
args
=
append
(
args
,
n
)
}
}
if
len
(
args
)
==
0
{
return
nil
}
c
.
Args
=
args
commands
=
append
(
commands
,
c
)
}
if
len
(
commands
)
==
0
{
return
nil
}
t
.
Cmds
=
commands
}
return
n
}
return
walk
(
n
)
}
template/template_test.go
View file @
41be2809
...
...
@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
}
for
n
,
tt
:=
range
cases
{
var
actual
bytes
.
Buffer
t
.
Run
(
n
,
func
(
t
*
testing
.
T
)
{
var
actual
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
actual
,
Values
{
Messages
:
tt
});
err
!=
nil
{
t
.
Fatal
(
err
)
}
...
...
@@ -120,6 +120,25 @@ func TestTemplate(t *testing.T) {
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
t
.
Run
(
"legacy"
,
func
(
t
*
testing
.
T
)
{
var
legacy
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
legacy
,
Values
{
Messages
:
tt
,
forceLegacy
:
true
});
err
!=
nil
{
t
.
Fatal
(
err
)
}
legacyBytes
:=
legacy
.
Bytes
()
if
slices
.
Contains
([]
string
{
"chatqa.gotmpl"
,
"openchat.gotmpl"
,
"vicuna.gotmpl"
},
match
)
&&
legacyBytes
[
len
(
legacyBytes
)
-
1
]
==
' '
{
t
.
Log
(
"removing trailing space from legacy output"
)
legacyBytes
=
legacyBytes
[
:
len
(
legacyBytes
)
-
1
]
}
else
if
slices
.
Contains
([]
string
{
"codellama-70b-instruct.gotmpl"
,
"llama2-chat.gotmpl"
,
"mistral-instruct.gotmpl"
},
match
)
{
t
.
Skip
(
"legacy outputs cannot be compared to messages outputs"
)
}
if
diff
:=
cmp
.
Diff
(
legacyBytes
,
actual
.
Bytes
());
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
}
})
}
...
...
@@ -136,6 +155,21 @@ func TestParse(t *testing.T) {
{
"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
,
"tools"
}},
{
"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ if eq .Role
\"
system
\"
}}SYSTEM: {{ .Content }}{{ else if eq .Role
\"
user
\"
}}USER: {{ .Content }}{{ else if eq .Role
\"
assistant
\"
}}ASSISTANT: {{ .Content }}{{ end }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
`{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else -}}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`
,
[]
string
{
"content"
,
"messages"
,
"prompt"
,
"response"
,
"role"
,
"system"
}},
}
for
_
,
tt
:=
range
cases
{
...
...
@@ -145,9 +179,8 @@ func TestParse(t *testing.T) {
t
.
Fatal
(
err
)
}
vars
:=
tmpl
.
Vars
()
if
!
slices
.
Equal
(
tt
.
vars
,
vars
)
{
t
.
Errorf
(
"expected %v, got %v"
,
tt
.
vars
,
vars
)
if
diff
:=
cmp
.
Diff
(
tmpl
.
Vars
(),
tt
.
vars
);
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
}
...
...
@@ -170,7 +203,7 @@ func TestExecuteWithMessages(t *testing.T) {
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq
(len (slice $.Messages
$index
)) 1
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (eq $index
0
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`
},
...
...
@@ -191,7 +224,7 @@ func TestExecuteWithMessages(t *testing.T) {
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq
(len (slice $.Messages
$index
)) 1
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (eq $index
0
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`
},
...
...
@@ -204,9 +237,9 @@ func TestExecuteWithMessages(t *testing.T) {
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
`[INST]
Hello friend![/INST] Hello human![INST]
You are a helpful assistant!
`[INST] You are a helpful assistant!
What is your name?[/INST] `
,
Hello friend![/INST] Hello human![INST]
What is your name?[/INST] `
,
},
{
"chatml"
,
...
...
@@ -221,7 +254,7 @@ What is your name?[/INST] `,
`
},
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq
(len (slice $.Messages
$index
)) 1
) $.System }}<|im_start|>system
{{- if and (eq .Role "user") (eq $index
0
) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
...
...
@@ -236,12 +269,12 @@ What is your name?[/INST] `,
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
`<|im_start|>user
`<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
...
...
@@ -300,8 +333,8 @@ Answer: `,
t
.
Fatal
(
err
)
}
if
b
.
String
()
!=
tt
.
expected
{
t
.
Errorf
(
"
expected
\n
%s,
\n
got
\n
%s"
,
tt
.
expected
,
b
.
String
()
)
if
diff
:=
cmp
.
Diff
(
b
.
String
()
,
tt
.
expected
);
diff
!=
""
{
t
.
Errorf
(
"
mismatch (-got +want):
\n
%s"
,
diff
)
}
})
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment