OpenDAS / ollama · Commits · 5d3f314b

Commit 5d3f314b authored Sep 03, 2023 by Michael Yang
remove marshalPrompt which is no longer needed
parent adaa1308

Showing 1 changed file with 19 additions and 42 deletions
llm/ggml_llama.go (+19, -42)
@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 		runner.Path,
 		append(params, "--port", strconv.Itoa(port))...,
 	)

-	var stderr bytes.Buffer
-	cmd.Stderr = &stderr
+	cmd.Stdout = os.Stderr
+	cmd.Stderr = os.Stderr

 	llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
@@ -437,15 +437,19 @@ type PredictRequest struct {
 	Stop     []string `json:"stop,omitempty"`
 }

-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-	// we need to find the trimmed prompt context before predicting so that we can return it to the client
-	trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+	prevConvo, err := llm.Decode(ctx, prevContext)
 	if err != nil {
-		return fmt.Errorf("marshaling prompt: %v", err)
+		return err
 	}
+
+	var nextContext strings.Builder
+	nextContext.WriteString(prevConvo)
+	nextContext.WriteString(prompt)
+
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	predReq := PredictRequest{
-		Prompt:   trimmedPrompt,
+		Prompt:   nextContext.String(),
 		Stream:   true,
 		NPredict: llm.NumPredict,
 		NKeep:    llm.NumKeep,
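With this change the token-level trimming that marshalPrompt performed is gone from this path: Predict decodes the previous context tokens (prevContext) back into text, concatenates that with the new prompt, and hands the re-encoded conversation to the caller through the callback once the stream finishes. A minimal, non-compiling sketch of the calling convention this implies; the prompt text and variable names here are hypothetical and not part of the commit:

	var prevContext []int // empty on the first request

	err := llm.Predict(ctx, prevContext, "Why is the sky blue?", func(r api.GenerateResponse) {
		fmt.Print(r.Response) // streamed generation text
		if r.Done {
			prevContext = r.Context // re-encoded conversation, passed back in on the next call
		}
	})
	if err != nil {
		log.Fatal(err)
	}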
@@ -491,7 +495,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	}

 	scanner := bufio.NewScanner(resp.Body)
-	genCtx := trimmedPrompt // start with the trimmed prompt
 	for scanner.Scan() {
 		select {
 		case <-ctx.Done():
@@ -512,11 +515,12 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 			}

 			if complete.Timings.PredictedMS > 0 {
-				genCtx += complete.Content
-				embd, err := llm.Encode(ctx, genCtx)
+				nextContext.WriteString(complete.Content)
+				embd, err := llm.Encode(ctx, nextContext.String())
 				if err != nil {
 					return fmt.Errorf("encoding context: %v", err)
 				}
+
 				fn(api.GenerateResponse{
 					Done:    true,
 					Context: embd,
@@ -528,12 +532,13 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 				return nil
 			}

-			var pred Prediction
-			if err := json.Unmarshal([]byte(evt), &pred); err != nil {
+			var p Prediction
+			if err := json.Unmarshal([]byte(evt), &p); err != nil {
 				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 			}
-			genCtx += pred.Content
-			fn(api.GenerateResponse{Response: pred.Content})
+
+			fn(api.GenerateResponse{Response: p.Content})
+			nextContext.WriteString(p.Content)
 			}
 		}
 	}
@@ -545,34 +550,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	return nil
 }

-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-	pEncode, err := llm.Encode(ctx, prompt)
-	if err != nil {
-		return "", fmt.Errorf("encoding prompt context: %w", err)
-	}
-	tokens := append(pCtx, pEncode...)
-	if llm.NumKeep < 0 {
-		llm.NumKeep = len(tokens)
-	}
-
-	// min(llm.NumCtx - 4, llm.NumKeep)
-	if llm.NumCtx-4 < llm.NumKeep {
-		llm.NumKeep = llm.NumCtx - 4
-	}
-
-	if len(tokens) >= llm.NumCtx {
-		// truncate input
-		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		tokens = truncated
-
-		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-	}
-	return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }
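For reference, the removed truncation arithmetic kept the first NumKeep tokens, erased whole numLeft-sized blocks immediately after them, and kept the remaining tail so the result fit inside NumCtx. A worked example with hypothetical values that are not taken from the commit:

	package main

	import "fmt"

	func main() {
		// Hypothetical values: a 2048-token window, 48 tokens pinned at the
		// start of the prompt, and a 3000-token prompt to be truncated.
		numCtx, numKeep, numTokens := 2048, 48, 3000

		numLeft := (numCtx - numKeep) / 2                              // (2048-48)/2 = 1000
		erasedBlocks := (numTokens - numKeep - numLeft - 1) / numLeft  // (3000-48-1000-1)/1000 = 1
		kept := numKeep + numTokens - (numKeep + erasedBlocks*numLeft) // 48 head + 1952 tail = 2000

		fmt.Println(numLeft, erasedBlocks, kept) // 1000 1 2000, which fits within numCtx
	}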