Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
41be2809
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "dcb6dd9b7a6c7ddd6875506f40597c0976fd02c5"
Commit
41be2809
authored
Jul 10, 2024
by
Michael Yang
Browse files
add system prompt to first legacy template
parent
4e262eb2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
140 additions
and
28 deletions
+140
-28
server/prompt_test.go
server/prompt_test.go
+1
-1
server/routes_create_test.go
server/routes_create_test.go
+2
-2
template/template.go
template/template.go
+90
-11
template/template_test.go
template/template_test.go
+47
-14
No files found.
server/prompt_test.go
View file @
41be2809
...
@@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) {
...
@@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) {
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
{
Role
:
"user"
,
Content
:
"A test. And a thumping good one at that, I'd wager."
},
},
},
expect
:
expect
{
expect
:
expect
{
prompt
:
"You're a test, Harry! I-I'm a what?
You are the Test Who Lived.
A test. And a thumping good one at that, I'd wager. "
,
prompt
:
"You
are the Test Who Lived. You
're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. "
,
},
},
},
},
}
}
...
...
server/routes_create_test.go
View file @
41be2809
...
@@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
...
@@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
checkFileExists
(
t
,
filepath
.
Join
(
p
,
"blobs"
,
"*"
),
[]
string
{
checkFileExists
(
t
,
filepath
.
Join
(
p
,
"blobs"
,
"*"
),
[]
string
{
filepath
.
Join
(
p
,
"blobs"
,
"sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b
7"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
68b0323b2f21572bc09ba07554b16b379a5713ee48ef8c25a7661a1f71cfce7
7"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce5
7c
d
ac
893db8139
"
),
filepath
.
Join
(
p
,
"blobs"
,
"sha256-
eb72fb7c550ee1f1dec4039bd65382acecf5f7536a30fb
7c
c
ac
e39a8d0cb590b
"
),
})
})
})
})
...
...
template/template.go
View file @
41be2809
...
@@ -143,11 +143,14 @@ func (t *Template) Vars() []string {
...
@@ -143,11 +143,14 @@ func (t *Template) Vars() []string {
type
Values
struct
{
type
Values
struct
{
Messages
[]
api
.
Message
Messages
[]
api
.
Message
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy
bool
}
}
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
func
(
t
*
Template
)
Execute
(
w
io
.
Writer
,
v
Values
)
error
{
system
,
collated
:=
collate
(
v
.
Messages
)
system
,
collated
:=
collate
(
v
.
Messages
)
if
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
if
!
v
.
forceLegacy
&&
slices
.
Contains
(
t
.
Vars
(),
"messages"
)
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
return
t
.
Template
.
Execute
(
w
,
map
[
string
]
any
{
"System"
:
system
,
"System"
:
system
,
"Messages"
:
collated
,
"Messages"
:
collated
,
...
@@ -157,15 +160,19 @@ func (t *Template) Execute(w io.Writer, v Values) error {
...
@@ -157,15 +160,19 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var
b
bytes
.
Buffer
var
b
bytes
.
Buffer
var
prompt
,
response
string
var
prompt
,
response
string
for
i
,
m
:=
range
collated
{
for
i
,
m
:=
range
collated
{
if
m
.
Role
==
"user"
{
switch
m
.
Role
{
case
"user"
:
prompt
=
m
.
Content
prompt
=
m
.
Content
}
else
{
if
i
!=
0
{
system
=
""
}
case
"assistant"
:
response
=
m
.
Content
response
=
m
.
Content
}
}
if
i
!=
len
(
collated
)
-
1
&&
prompt
!=
""
&&
response
!=
""
{
if
i
!=
len
(
collated
)
-
1
&&
prompt
!=
""
&&
response
!=
""
{
if
err
:=
t
.
Template
.
Execute
(
&
b
,
map
[
string
]
any
{
if
err
:=
t
.
Template
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
""
,
"System"
:
system
,
"Prompt"
:
prompt
,
"Prompt"
:
prompt
,
"Response"
:
response
,
"Response"
:
response
,
});
err
!=
nil
{
});
err
!=
nil
{
...
@@ -178,18 +185,21 @@ func (t *Template) Execute(w io.Writer, v Values) error {
...
@@ -178,18 +185,21 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}
}
var
cut
bool
var
cut
bool
tree
:=
t
.
Template
.
Copy
()
nodes
:=
deleteNode
(
t
.
Template
.
Root
.
Copy
(),
func
(
n
parse
.
Node
)
bool
{
// for the last message, cut everything after "{{ .Response }}"
switch
t
:=
n
.
(
type
)
{
tree
.
Root
.
Nodes
=
slices
.
DeleteFunc
(
tree
.
Root
.
Nodes
,
func
(
n
parse
.
Node
)
bool
{
case
*
parse
.
ActionNode
:
if
slices
.
Contains
(
parseNode
(
n
),
"Response"
)
{
case
*
parse
.
FieldNode
:
cut
=
true
if
slices
.
Contains
(
t
.
Ident
,
"Response"
)
{
cut
=
true
}
}
}
return
cut
return
cut
})
})
if
err
:=
template
.
Must
(
template
.
New
(
""
)
.
AddParseTree
(
""
,
tree
))
.
Execute
(
&
b
,
map
[
string
]
any
{
tree
:=
parse
.
Tree
{
Root
:
nodes
.
(
*
parse
.
ListNode
)}
"System"
:
system
,
if
err
:=
template
.
Must
(
template
.
New
(
""
)
.
AddParseTree
(
""
,
&
tree
))
.
Execute
(
&
b
,
map
[
string
]
any
{
"System"
:
""
,
"Prompt"
:
prompt
,
"Prompt"
:
prompt
,
});
err
!=
nil
{
});
err
!=
nil
{
return
err
return
err
...
@@ -286,3 +296,72 @@ func parseNode(n parse.Node) []string {
...
@@ -286,3 +296,72 @@ func parseNode(n parse.Node) []string {
return
nil
return
nil
}
}
// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
func
deleteNode
(
n
parse
.
Node
,
fn
func
(
parse
.
Node
)
bool
)
parse
.
Node
{
var
walk
func
(
n
parse
.
Node
)
parse
.
Node
walk
=
func
(
n
parse
.
Node
)
parse
.
Node
{
if
fn
(
n
)
{
return
nil
}
switch
t
:=
n
.
(
type
)
{
case
*
parse
.
ListNode
:
var
nodes
[]
parse
.
Node
for
_
,
c
:=
range
t
.
Nodes
{
if
n
:=
walk
(
c
);
n
!=
nil
{
nodes
=
append
(
nodes
,
n
)
}
}
t
.
Nodes
=
nodes
return
t
case
*
parse
.
IfNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
WithNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
RangeNode
:
t
.
BranchNode
=
*
(
walk
(
&
t
.
BranchNode
)
.
(
*
parse
.
BranchNode
))
case
*
parse
.
BranchNode
:
t
.
List
=
walk
(
t
.
List
)
.
(
*
parse
.
ListNode
)
if
t
.
ElseList
!=
nil
{
t
.
ElseList
=
walk
(
t
.
ElseList
)
.
(
*
parse
.
ListNode
)
}
case
*
parse
.
ActionNode
:
n
:=
walk
(
t
.
Pipe
)
if
n
==
nil
{
return
nil
}
t
.
Pipe
=
n
.
(
*
parse
.
PipeNode
)
case
*
parse
.
PipeNode
:
var
commands
[]
*
parse
.
CommandNode
for
_
,
c
:=
range
t
.
Cmds
{
var
args
[]
parse
.
Node
for
_
,
a
:=
range
c
.
Args
{
if
n
:=
walk
(
a
);
n
!=
nil
{
args
=
append
(
args
,
n
)
}
}
if
len
(
args
)
==
0
{
return
nil
}
c
.
Args
=
args
commands
=
append
(
commands
,
c
)
}
if
len
(
commands
)
==
0
{
return
nil
}
t
.
Cmds
=
commands
}
return
n
}
return
walk
(
n
)
}
template/template_test.go
View file @
41be2809
...
@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
...
@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
}
}
for
n
,
tt
:=
range
cases
{
for
n
,
tt
:=
range
cases
{
var
actual
bytes
.
Buffer
t
.
Run
(
n
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
n
,
func
(
t
*
testing
.
T
)
{
var
actual
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
actual
,
Values
{
Messages
:
tt
});
err
!=
nil
{
if
err
:=
tmpl
.
Execute
(
&
actual
,
Values
{
Messages
:
tt
});
err
!=
nil
{
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
...
@@ -120,6 +120,25 @@ func TestTemplate(t *testing.T) {
...
@@ -120,6 +120,25 @@ func TestTemplate(t *testing.T) {
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
}
})
})
t
.
Run
(
"legacy"
,
func
(
t
*
testing
.
T
)
{
var
legacy
bytes
.
Buffer
if
err
:=
tmpl
.
Execute
(
&
legacy
,
Values
{
Messages
:
tt
,
forceLegacy
:
true
});
err
!=
nil
{
t
.
Fatal
(
err
)
}
legacyBytes
:=
legacy
.
Bytes
()
if
slices
.
Contains
([]
string
{
"chatqa.gotmpl"
,
"openchat.gotmpl"
,
"vicuna.gotmpl"
},
match
)
&&
legacyBytes
[
len
(
legacyBytes
)
-
1
]
==
' '
{
t
.
Log
(
"removing trailing space from legacy output"
)
legacyBytes
=
legacyBytes
[
:
len
(
legacyBytes
)
-
1
]
}
else
if
slices
.
Contains
([]
string
{
"codellama-70b-instruct.gotmpl"
,
"llama2-chat.gotmpl"
,
"mistral-instruct.gotmpl"
},
match
)
{
t
.
Skip
(
"legacy outputs cannot be compared to messages outputs"
)
}
if
diff
:=
cmp
.
Diff
(
legacyBytes
,
actual
.
Bytes
());
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
}
}
})
})
}
}
...
@@ -136,6 +155,21 @@ func TestParse(t *testing.T) {
...
@@ -136,6 +155,21 @@ func TestParse(t *testing.T) {
{
"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
,
"tools"
}},
{
"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}"
,
[]
string
{
"prompt"
,
"response"
,
"system"
,
"tools"
}},
{
"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ if eq .Role
\"
system
\"
}}SYSTEM: {{ .Content }}{{ else if eq .Role
\"
user
\"
}}USER: {{ .Content }}{{ else if eq .Role
\"
assistant
\"
}}ASSISTANT: {{ .Content }}{{ end }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
"{{ range .Messages }}{{ if eq .Role
\"
system
\"
}}SYSTEM: {{ .Content }}{{ else if eq .Role
\"
user
\"
}}USER: {{ .Content }}{{ else if eq .Role
\"
assistant
\"
}}ASSISTANT: {{ .Content }}{{ end }}{{ end }}"
,
[]
string
{
"content"
,
"messages"
,
"role"
}},
{
`{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else -}}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`
,
[]
string
{
"content"
,
"messages"
,
"prompt"
,
"response"
,
"role"
,
"system"
}},
}
}
for
_
,
tt
:=
range
cases
{
for
_
,
tt
:=
range
cases
{
...
@@ -145,9 +179,8 @@ func TestParse(t *testing.T) {
...
@@ -145,9 +179,8 @@ func TestParse(t *testing.T) {
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
vars
:=
tmpl
.
Vars
()
if
diff
:=
cmp
.
Diff
(
tmpl
.
Vars
(),
tt
.
vars
);
diff
!=
""
{
if
!
slices
.
Equal
(
tt
.
vars
,
vars
)
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
t
.
Errorf
(
"expected %v, got %v"
,
tt
.
vars
,
vars
)
}
}
})
})
}
}
...
@@ -170,7 +203,7 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -170,7 +203,7 @@ func TestExecuteWithMessages(t *testing.T) {
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
{
"no response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`{{- range $index, $_ := .Messages }}
{
"messages"
,
`{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq
(len (slice $.Messages
$index
)) 1
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (eq $index
0
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}
{{- end }}`
},
{{- end }}`
},
...
@@ -191,7 +224,7 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -191,7 +224,7 @@ func TestExecuteWithMessages(t *testing.T) {
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"response"
,
`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`
},
{
"messages"
,
`
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq
(len (slice $.Messages
$index
)) 1
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- if eq .Role "user" }}[INST] {{ if and (eq $index
0
) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}
{{- end }}`
},
{{- end }}`
},
...
@@ -204,9 +237,9 @@ func TestExecuteWithMessages(t *testing.T) {
...
@@ -204,9 +237,9 @@ func TestExecuteWithMessages(t *testing.T) {
{
Role
:
"user"
,
Content
:
"What is your name?"
},
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
},
},
`[INST]
Hello friend![/INST] Hello human![INST]
You are a helpful assistant!
`[INST] You are a helpful assistant!
What is your name?[/INST] `
,
Hello friend![/INST] Hello human![INST]
What is your name?[/INST] `
,
},
},
{
{
"chatml"
,
"chatml"
,
...
@@ -221,7 +254,7 @@ What is your name?[/INST] `,
...
@@ -221,7 +254,7 @@ What is your name?[/INST] `,
`
},
`
},
{
"messages"
,
`
{
"messages"
,
`
{{- range $index, $_ := .Messages }}
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq
(len (slice $.Messages
$index
)) 1
) $.System }}<|im_start|>system
{{- if and (eq .Role "user") (eq $index
0
) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{ .Content }}<|im_end|>{{ "\n" }}
...
@@ -236,12 +269,12 @@ What is your name?[/INST] `,
...
@@ -236,12 +269,12 @@ What is your name?[/INST] `,
{
Role
:
"user"
,
Content
:
"What is your name?"
},
{
Role
:
"user"
,
Content
:
"What is your name?"
},
},
},
},
},
`<|im_start|>user
`<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
Hello friend!<|im_end|>
Hello friend!<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
Hello human!<|im_end|>
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
<|im_start|>user
What is your name?<|im_end|>
What is your name?<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
...
@@ -300,8 +333,8 @@ Answer: `,
...
@@ -300,8 +333,8 @@ Answer: `,
t
.
Fatal
(
err
)
t
.
Fatal
(
err
)
}
}
if
b
.
String
()
!=
tt
.
expected
{
if
diff
:=
cmp
.
Diff
(
b
.
String
()
,
tt
.
expected
);
diff
!=
""
{
t
.
Errorf
(
"
expected
\n
%s,
\n
got
\n
%s"
,
tt
.
expected
,
b
.
String
()
)
t
.
Errorf
(
"
mismatch (-got +want):
\n
%s"
,
diff
)
}
}
})
})
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment