Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
36c87c43
Commit
36c87c43
authored
Jul 12, 2024
by
Michael Yang
Browse files
template: preprocess message and collect system
parent
179737fe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
67 deletions
+23
-67
template/template.go
template/template.go
+15
-22
template/template_test.go
template/template_test.go
+8
-45
No files found.
template/template.go
View file @
36c87c43
...
@@ -102,22 +102,8 @@ var response = parse.ActionNode{
...
@@ -102,22 +102,8 @@ var response = parse.ActionNode{
},
},
}
}
var
funcs
=
template
.
FuncMap
{
// contents returns the contents of messages with an optional role filter
"contents"
:
func
(
v
[]
*
api
.
Message
,
role
...
string
)
string
{
var
parts
[]
string
for
_
,
m
:=
range
v
{
if
len
(
role
)
==
0
||
role
[
0
]
==
""
||
m
.
Role
==
role
[
0
]
{
parts
=
append
(
parts
,
m
.
Content
)
}
}
return
strings
.
Join
(
parts
,
"
\n\n
"
)
},
}
func
Parse
(
s
string
)
(
*
Template
,
error
)
{
func
Parse
(
s
string
)
(
*
Template
,
error
)
{
tmpl
:=
template
.
New
(
""
)
.
Option
(
"missingkey=zero"
)
.
Funcs
(
funcs
)
tmpl
:=
template
.
New
(
""
)
.
Option
(
"missingkey=zero"
)
tmpl
,
err
:=
tmpl
.
Parse
(
s
)
tmpl
,
err
:=
tmpl
.
Parse
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -163,15 +149,16 @@ type Values struct {
...
@@ -163,15 +149,16 @@ type Values struct {
}
}
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
collated
:=
collate
(
v
.
Messages
)
system
,
collated
:=
collate
(
v
.
Messages
)
if
!
v
.
forceLegacy
&&
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
if
!
v
.
forceLegacy
&&
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
"System"
:
system
,
"Messages"
:
collated
,
"Messages"
:
collated
,
})
})
}
}
var
b
bytes
.
Buffer
var
b
bytes
.
Buffer
var
system
,
prompt
,
response
string
var
prompt
,
response
string
for
i
,
m
:=
range
collated
{
for
i
,
m
:=
range
collated
{
switch
m
.
Role
{
switch
m
.
Role
{
case
"system"
:
case
"system"
:
...
@@ -223,11 +210,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
...
@@ -223,11 +210,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}
}
// collate messages based on role. consecutive messages of the same role are merged
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// into a single message. collate also collects and returns all system messages.
// which are templated separately. As a side effect, it mangles message content adding image
// collate mutates message content adding image tags ([img-%d]) as needed
// tags ([img-%d]) as needed
func
collate
(
msgs
[]
api
.
Message
)
(
string
,
[]
*
api
.
Message
)
{
func
collate
(
msgs
[]
api
.
Message
)
(
collated
[]
*
api
.
Message
)
{
var
n
int
var
n
int
var
system
[]
string
var
collated
[]
*
api
.
Message
for
i
:=
range
msgs
{
for
i
:=
range
msgs
{
msg
:=
msgs
[
i
]
msg
:=
msgs
[
i
]
for
range
msg
.
Images
{
for
range
msg
.
Images
{
...
@@ -240,6 +229,10 @@ func collate(msgs []api.Message) (collated []*api.Message) {
...
@@ -240,6 +229,10 @@ func collate(msgs []api.Message) (collated []*api.Message) {
n
++
n
++
}
}
if
msg
.
Role
==
"system"
{
system
=
append
(
system
,
msg
.
Content
)
}
if
len
(
collated
)
>
0
&&
collated
[
len
(
collated
)
-
1
]
.
Role
==
msg
.
Role
{
if
len
(
collated
)
>
0
&&
collated
[
len
(
collated
)
-
1
]
.
Role
==
msg
.
Role
{
collated
[
len
(
collated
)
-
1
]
.
Content
+=
"
\n\n
"
+
msg
.
Content
collated
[
len
(
collated
)
-
1
]
.
Content
+=
"
\n\n
"
+
msg
.
Content
}
else
{
}
else
{
...
@@ -247,7 +240,7 @@ func collate(msgs []api.Message) (collated []*api.Message) {
...
@@ -247,7 +240,7 @@ func collate(msgs []api.Message) (collated []*api.Message) {
}
}
}
}
return
return
strings
.
Join
(
system
,
"
\n\n
"
),
collated
}
}
func
parseNode
(
n
parse
.
Node
)
[]
string
{
func
parseNode
(
n
parse
.
Node
)
[]
string
{
...
...
template/template_test.go
View file @
36c87c43
...
@@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) {
{
"response"
,
`[INST] {{ if .System }}{{ .System }}
{
"response"
,
`[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`{{- $system := contents .Messages "system" -}}
{
"messages"
,
`[INST] {{ if .System }}{{ .System }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{ end }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`
},
{{- end }}`
},
},
},
Values
{
Values
{
...
@@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) {
{
"response"
,
`[INST] {{ if .System }}{{ .System }}
{
"response"
,
`[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`{{- $system := contents .Messages "system" -}}
{
"messages"
,
`[INST] {{ if .System }}{{ .System }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{ end }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`
},
{{- end }}`
},
},
},
Values
{
Values
{
...
@@ -363,36 +359,3 @@ Answer: `,
...
@@ -363,36 +359,3 @@ Answer: `,
})
})
}
}
}
}
func
TestFuncs
(
t
*
testing
.
T
)
{
t
.
Run
(
"contents"
,
func
(
t
*
testing
.
T
)
{
cases
:=
map
[
string
]
string
{
""
:
"A
\n\n
B
\n\n
C
\n\n
D
\n\n
E
\n\n
F"
,
"system"
:
"A
\n\n
F"
,
"user"
:
"B
\n\n
E"
,
"assistant"
:
"C
\n\n
D"
,
}
s
:=
[]
*
api
.
Message
{
{
Role
:
"system"
,
Content
:
"A"
},
{
Role
:
"user"
,
Content
:
"B"
},
{
Role
:
"assistant"
,
Content
:
"C"
},
{
Role
:
"assistant"
,
Content
:
"D"
},
{
Role
:
"user"
,
Content
:
"E"
},
{
Role
:
"system"
,
Content
:
"F"
},
}
fn
,
ok
:=
funcs
[
"contents"
]
.
(
func
([]
*
api
.
Message
,
...
string
)
string
)
if
!
ok
{
t
.
Fatal
(
"contents is not a function"
)
}
for
k
,
v
:=
range
cases
{
t
.
Run
(
k
,
func
(
t
*
testing
.
T
)
{
if
diff
:=
cmp
.
Diff
(
fn
(
s
,
k
),
v
);
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
}
})
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment