Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
41be2809
Commit
41be2809
authored
Jul 10, 2024
by
Michael Yang
Browse files
add system prompt to first legacy template
parent
4e262eb2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
140 additions
and
28 deletions
+140
-28
server/prompt_test.go
server/prompt_test.go
+1
-1
server/routes_create_test.go
server/routes_create_test.go
+2
-2
template/template.go
template/template.go
+90
-11
template/template_test.go
template/template_test.go
+47
-14
No files found.
server/prompt_test.go
View file @
41be2809
...
@@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) {
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
},
expect
:
expect
{
expect
:
expect
{
prompt
:
"You're a test, Harry! I-I'm a what?
You are the Test Who Lived.
A test. And a thumping good one at that, I'd wager. "
,
prompt
:
"You
are the Test Who Lived. You
're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. "
,
},
},
},
},
}
}
...
...
server/routes_create_test.go
View file @
41be2809
...
@@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
...
@@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
checkFileExists
(
t
,
filepath
.
Join
(
p
,
"blobs"
,
"*"
),
[]
string
{
checkFileExists
(
t
,
filepath
.
Join
(
p
,
"blobs"
,
"*"
),
[]
string
{
filepath
.
Join
(
p
,
"blobs"
,
"sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b
7"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
68b0323b2f21572bc09ba07554b16b379a5713ee48ef8c25a7661a1f71cfce7
7"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce5
7c
d
ac
893db8139
"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
eb72fb7c550ee1f1dec4039bd65382acecf5f7536a30fb
7c
c
ac
e39a8d0cb590b
"
),
})
})
})
})
...
...
template/template.go
View file @
41be2809
...
@@ -143,11 +143,14 @@ func (t *Template) Vars() []string {
...
@@ -143,11 +143,14 @@ func (t *Template) Vars() []string {
type
Values
struct
{
type
Values
struct
{
Messages
[]
api
.
Message
Messages
[]
api
.
Message
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy
bool
}
}
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
system
,
collated
:=
collate
(
v
.
Messages
)
system
,
collated
:=
collate
(
v
.
Messages
)
if
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
if
!
v
.
forceLegacy
&&
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
"System"
:
system
,
"System"
:
system
,
"Messages"
:
collated
,
"Messages"
:
collated
,
...
@@ -157,15 +160,19 @@ func (t *Template) Execute(w io.Writer, v Values) error {
...
@@ -157,15 +160,19 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var
b
bytes
.
Buffer
var
b
bytes
.
Buffer
var
prompt
,
response
string
var
prompt
,
response
string
for
i
,
m
:=
range
collated
{
for
i
,
m
:=
range
collated
{
if
m
.
Role
==
"user"
{
switch
m
.
Role
{
case
"user"
:
prompt
=
m
.
Content
prompt
=
m
.
Content
}
else
{
if
i
!=
0
{
system
=
""
}
case
"assistant"
:
response
=
m
.
Content
response
=
m
.
Content
}
}
if
i
!=
len
(
collated
)
-
1
&&
prompt
!=
""
&&
response
!=
""
{
if
i
!=
len
(
collated
)
-
1
&&
prompt
!=
""
&&
response
!=
""
{
if
err
:=
t
.
Template
.
Execute
(
&
b
,
map
[
string
]
any
{
if
err
:=
t
.
Template
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
""
,
"System"
:
system
,
"Prompt"
:
prompt
,
"Prompt"
:
prompt
,
"Response"
:
response
,
"Response"
:
response
,
});
err
!=
nil
{
});
err
!=
nil
{
...
@@ -178,18 +185,21 @@ func (t *Template) Execute(w io.Writer, v Values) error {
...
@@ -178,18 +185,21 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}
}
var
cut
bool
var
cut
bool
tree
:=
t
.
Template
.
Copy
()
nodes
:=
deleteNode
(
t
.
Template
.
Root
.
Copy
(),
func
(
n
parse
.
Node
)
bool
{
// for the last message, cut everything after "{{ .Response }}"
switch
t
:=
n
.
(
type
)
{
tree
.
Root
.
Nodes
=
slices
.
DeleteFunc
(
tree
.
Root
.
Nodes
,
func
(
n
parse
.
Node
)
bool
{
case
*
parse
.
ActionNode
:
if
slices
.
Contains
(
parseNode
(
n
),
"Response"
)
{
case
*
parse
.
FieldNode
:
cut
=
true
if
slices
.
Contains
(
t
.
Ident
,
"Response"
)
{
cut
=
true
}
}
}
return
cut
return
cut
})
})
if
err
:=
template
.
Must
(
template
.
New
(
""
)
.
AddParseTree
(
""
,
tree
))
.
Execute
(
&
b
,
map
[
string
]
any
{
tree
:=
parse
.
Tree
{
Root
:
nodes
.
(
*
parse
.
ListNode
)}
"System"
:
system
,
if
err
:=
template
.
Must
(
template
.
New
(
""
)
.
AddParseTree
(
""
,
&
tree
))
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
""
,
"Prompt"
:
prompt
,
"Prompt"
:
prompt
,
});
err
!=
nil
{
});
err
!=
nil
{
return
err
return
err
...
@@ -286,3 +296,72 @@ func parseNode(n parse.Node) []string {
...
@@ -286,3 +296,72 @@ func parseNode(n parse.Node) []string {
return
nil
return
nil
}
}
// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
func
deleteNode
(
n
parse
.
Node
,
fn
func
(
parse
.
Node
)
bool
)
parse
.
Node
{
var
walk
func
(
n
parse
.
Node
)
parse
.
Node
walk
=
func
(
n
parse
.
Node
)
parse
.
Node
{
if
fn
(
n
)
{
return
nil
}
switch
t
:=
n
.
(
type
)
{
case
*
parse
.
ListNode
:
var
nodes
[]
parse
.
Node
for
_
,
c
:=
range
t
.
Nodes
{
if
n
:=
walk
(
c
);
n
!=
nil
{
nodes
=
append
(
nodes
,
n
)
}
}
t
.
Nodes
=
nodes
return
t
case
*
parse
.
IfNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
WithNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
RangeNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
BranchNode
:
t
.
List
=
walk
(
t
.
List
)
.
(
*
parse
.
ListNode
)
if
t
.
ElseList
!=
nil
{
t
.
ElseList
=
walk
(
t
.
ElseList
)
.
(
*
parse
.
ListNode
)
}
case
*
parse
.
ActionNode
:
n
:=
walk
(
t
.
Pipe
)
if
n
==
nil
{
return
nil
}
t
.
Pipe
=
n
.
(
*
parse
.
PipeNode
)
case
*
parse
.
PipeNode
:
var
commands
[]
*
parse
.
CommandNode
for
_
,
c
:=
range
t
.
Cmds
{
var
args
[]
parse
.
Node
for
_
,
a
:=
range
c
.
Args
{
if
n
:=
walk
(
a
);
n
!=
nil
{
args
=
append
(
args
,
n
)
}
}
if
len
(
args
)
==
0
{
return
nil
}
c
.
Args
=
args
commands
=
append
(
commands
,
c
)
}
if
len
(
commands
)
==
0
{
return
nil
}
t
.
Cmds
=
commands
}
return
n
}
return
walk
(
n
)
}
template/template_test.go
View file @
41be2809
...
@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
...
@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
}
}
for
n
,
tt
:=
range
cases
{
for
n
,
tt
:=
range
cases
{
var
actual
bytes
.
Buffer
t
.
Run
(
n
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
n
,
func
(
t
*
testing
.
T
)
{
var
actual
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
actual
,
Values
{
Messages
:
tt
});
err
!=
nil
{
if
err
:=
tmpl
.
Execute
(
&
actual
,
Values
{
Messages
:
tt
});
err
!=
nil
{
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
...
@@ -120,6 +120,25 @@ func TestTemplate(t *testing.T) {
...
@@ -120,6 +120,25 @@ func TestTemplate(t *testing.T) {
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
}
})
})
t
.
Run
(
"legacy"
,
func
(
t
*
testing
.
T
)
{
var
legacy
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
legacy
,
Values
{
Messages
:
tt
,
forceLegacy
:
true
});
err
!=
nil
{
t
.
Fatal
(
err
)
}
legacyBytes
:=
legacy
.
Bytes
()
if
slices
.
Contains
([]
string
{
"chatqa.gotmpl"
,
"openchat.gotmpl"
,
"vicuna.gotmpl"
},
match
)
&&
legacyBytes
[
len
(
legacyBytes
)
-
1
]
==
' '
{
t
.
Log
(
"removing trailing space from legacy output"
)
legacyBytes
=
legacyBytes
[
:
len
(
legacyBytes
)
-
1
]
}
else
if
slices
.
Contains
([]
string
{
"codellama-70b-instruct.gotmpl"
,
"llama2-chat.gotmpl"
,
"mistral-instruct.gotmpl"
},
match
)
{
t
.
Skip
(
"legacy outputs cannot be compared to messages outputs"
)
}
if
diff
:=
cmp
.
Diff
(
legacyBytes
,
actual
.
Bytes
());
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
}
}
})
})
}
}
...
@@ -136,6 +155,21 @@ func TestParse(t *testing.T) {
...
@@ -136,6 +155,21 @@ func TestParse(t *testing.T) {
{
"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
,
"tools"
}},
{
"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
,
"tools"
}},
{
"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ if eq .Role
\"
system
\"
}}SYSTEM: {{ .Content }}{{ else if eq .Role
\"
user
\"
}}USER: {{ .Content }}{{ else if eq .Role
\"
assistant
\"
}}ASSISTANT: {{ .Content }}{{ end }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ if eq .Role
\"
system
\"
}}SYSTEM: {{ .Content }}{{ else if eq .Role
\"
user
\"
}}USER: {{ .Content }}{{ else if eq .Role
\"
assistant
\"
}}ASSISTANT: {{ .Content }}{{ end }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
`{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else -}}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`
,
[]
string
{
"content"
,
"messages"
,
"prompt"
,
"response"
,
"role"
,
"system"
}},
}
}
for
_
,
tt
:=
range
cases
{
for
_
,
tt
:=
range
cases
{
...
@@ -145,9 +179,8 @@ func TestParse(t *testing.T) {
...
@@ -145,9 +179,8 @@ func TestParse(t *testing.T) {
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
vars
:=
tmpl
.
Vars
()
if
diff
:=
cmp
.
Diff
(
tmpl
.
Vars
(),
tt
.
vars
);
diff
!=
""
{
if
!
slices
.
Equal
(
tt
.
vars
,
vars
)
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
t
.
Errorf
(
"expected %v, got %v"
,
tt
.
vars
,
vars
)
}
}
})
})
}
}
...
@@ -170,7 +203,7 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -170,7 +203,7 @@ func TestExecuteWithMessages(t *testing.T) {
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`{{- range $index, $_ := .Messages }}
{
"messages"
,
`{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq
(len (slice $.Messages
$index
)) 1
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (eq $index
0
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}
{{- end }}`
},
{{- end }}`
},
...
@@ -191,7 +224,7 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -191,7 +224,7 @@ func TestExecuteWithMessages(t *testing.T) {
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq
(len (slice $.Messages
$index
)) 1
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (eq $index
0
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}
{{- end }}`
},
{{- end }}`
},
...
@@ -204,9 +237,9 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -204,9 +237,9 @@ func TestExecuteWithMessages(t *testing.T) {
{
Role
:
"user"
,
Content
:
"What is your name?"
},
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
},
},
`[INST]
Hello friend![/INST] Hello human![INST]
You are a helpful assistant!
`[INST] You are a helpful assistant!
What is your name?[/INST] `
,
Hello friend![/INST] Hello human![INST]
What is your name?[/INST] `
,
},
},
{
{
"chatml"
,
"chatml"
,
...
@@ -221,7 +254,7 @@ What is your name?[/INST] `,
...
@@ -221,7 +254,7 @@ What is your name?[/INST] `,
`
},
`
},
{
"messages"
,
`
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq
(len (slice $.Messages
$index
)) 1
) $.System }}<|im_start|>system
{{- if and (eq .Role "user") (eq $index
0
) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{ .Content }}<|im_end|>{{ "\n" }}
...
@@ -236,12 +269,12 @@ What is your name?[/INST] `,
...
@@ -236,12 +269,12 @@ What is your name?[/INST] `,
{
Role
:
"user"
,
Content
:
"What is your name?"
},
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
},
},
`<|im_start|>user
`<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
Hello friend!<|im_end|>
Hello friend!<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
Hello human!<|im_end|>
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
<|im_start|>user
What is your name?<|im_end|>
What is your name?<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
...
@@ -300,8 +333,8 @@ Answer: `,
...
@@ -300,8 +333,8 @@ Answer: `,
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
if
b
.
String
()
!=
tt
.
expected
{
if
diff
:=
cmp
.
Diff
(
b
.
String
()
,
tt
.
expected
);
diff
!=
""
{
t
.
Errorf
(
"
expected
\n
%s,
\n
got
\n
%s"
,
tt
.
expected
,
b
.
String
()
)
t
.
Errorf
(
"
mismatch (-got +want):
\n
%s"
,
diff
)
}
}
})
})
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment