OpenDAS / ollama · Commit 7b5aefb4 (Unverified)
Authored Sep 05, 2023 by Michael Yang; committed by GitHub on Sep 05, 2023
Parents: 7fa6e516 7efbc843

Merge pull request #462 from jmorganca/mxyng/rm-marshal-prompt

remove marshalPrompt which is no longer needed
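In short, Predict no longer pre-trims the prompt on the client. As the diff below shows, it decodes the previous context back to text, appends the new prompt, and only re-encodes the whole conversation when the final response is produced. A minimal sketch of that flow, with error handling shortened and using the Decode/Encode helpers that appear in the diff:

    // prevContext holds the token context returned from the previous request.
    prevConvo, err := llm.Decode(ctx, prevContext) // tokens -> text
    if err != nil {
        return err
    }

    var nextContext strings.Builder
    nextContext.WriteString(prevConvo) // prior conversation
    nextContext.WriteString(prompt)    // new prompt

    // nextContext.String() is sent as PredictRequest.Prompt; once the stream
    // reports stop, nextContext (now including the model output) is re-encoded
    // with llm.Encode and returned to the client as the new context.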
Showing 2 changed files with 38 additions and 82 deletions (+38 -82)

llm/ggml_llama.go   +36 -81
server/routes.go    +2  -1
llm/ggml_llama.go (view file @ 7b5aefb4)
@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Options
         runner.Path,
         append(params, "--port", strconv.Itoa(port))...,
     )

-    var stderr bytes.Buffer
-    cmd.Stderr = &stderr
+    cmd.Stdout = os.Stderr
+    cmd.Stderr = os.Stderr

     llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
     llm.Options = opts
 }

-type Prediction struct {
-    Content string `json:"content"`
-    Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
     FrequencyPenalty float64 `json:"frequency_penalty"`
     IgnoreEOS        bool    `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }

 type Timings struct {
-    PredictedMS         float64 `json:"predicted_ms"`
-    PredictedN          int     `json:"predicted_n"`
-    PredictedPerSecond  float64 `json:"predicted_per_second"`
-    PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-    PromptMS            float64 `json:"prompt_ms"`
-    PromptN             int     `json:"prompt_n"`
-    PromptPerSecond     float64 `json:"prompt_per_second"`
-    PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+    PredictedN  int     `json:"predicted_n"`
+    PredictedMS float64 `json:"predicted_ms"`
+    PromptN     int     `json:"prompt_n"`
+    PromptMS    float64 `json:"prompt_ms"`
 }

-type PredictComplete struct {
-    Content            string             `json:"content"`
-    GenerationSettings GenerationSettings `json:"generation_settings"`
-    Model              string             `json:"model"`
-    Prompt             string             `json:"prompt"`
-    Stop               bool               `json:"stop"`
-    StoppedEOS         bool               `json:"stopped_eos"`
-    StoppedLimit       bool               `json:"stopped_limit"`
-    StoppedWord        bool               `json:"stopped_word"`
-    StoppingWord       string             `json:"stopping_word"`
-    Timings            Timings            `json:"timings"`
-    TokensCached       int                `json:"tokens_cached"`
-    TokensEvaluated    int                `json:"tokens_evaluated"`
-    TokensPredicted    int                `json:"tokens_predicted"`
-    Truncated          bool               `json:"truncated"`
+type Prediction struct {
+    Content string `json:"content"`
+    Model   string `json:"model"`
+    Prompt  string `json:"prompt"`
+    Stop    bool   `json:"stop"`
+
+    Timings `json:"timings"`
 }

 type PredictRequest struct {
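The trimmed-down Prediction struct above is what each streamed completion event is unmarshaled into (see the Predict changes below). As an illustration only, with made-up field values and assuming the encoding/json import, a hypothetical event body would decode like this:

    // Hypothetical payload; the shape follows the Prediction and Timings structs above.
    evt := `{"content":" world","model":"m","prompt":"hello","stop":true,"timings":{"predicted_n":12,"predicted_ms":480,"prompt_n":5,"prompt_ms":30}}`

    var p Prediction
    if err := json.Unmarshal([]byte(evt), &p); err != nil {
        return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
    }
    // p.Content == " world", p.Stop == true, and p.PredictedN == 12 via the embedded Timings.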
@@ -437,15 +420,19 @@ type PredictRequest struct {
     Stop             []string `json:"stop,omitempty"`
 }

-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-    // we need to find the trimmed prompt context before predicting so that we can return it to the client
-    trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+    prevConvo, err := llm.Decode(ctx, prevContext)
     if err != nil {
-        return fmt.Errorf("marshaling prompt: %v", err)
+        return err
     }

+    var nextContext strings.Builder
+    nextContext.WriteString(prevConvo)
+    nextContext.WriteString(prompt)
+
     endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
     predReq := PredictRequest{
-        Prompt:   trimmedPrompt,
+        Prompt:   nextContext.String(),
         Stream:   true,
         NPredict: llm.NumPredict,
         NKeep:    llm.NumKeep,
@@ -491,7 +478,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
     }

     scanner := bufio.NewScanner(resp.Body)
-    genCtx := trimmedPrompt // start with the trimmed prompt
     for scanner.Scan() {
         select {
         case <-ctx.Done():
@@ -506,34 +492,31 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
             // Read data from the server-side event stream
             if strings.HasPrefix(line, "data: ") {
                 evt := line[6:]
-                var complete PredictComplete
-                if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-                    return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+                var p Prediction
+                if err := json.Unmarshal([]byte(evt), &p); err != nil {
+                    return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
                 }

-                if complete.Timings.PredictedMS > 0 {
-                    genCtx += complete.Content
-                    embd, err := llm.Encode(ctx, genCtx)
+                fn(api.GenerateResponse{Response: p.Content})
+
+                nextContext.WriteString(p.Content)
+
+                if p.Stop {
+                    embd, err := llm.Encode(ctx, nextContext.String())
                     if err != nil {
                         return fmt.Errorf("encoding context: %v", err)
                     }

                     fn(api.GenerateResponse{
                         Done:               true,
                         Context:            embd,
-                        PromptEvalCount:    int(complete.Timings.PromptN),
-                        PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-                        EvalCount:          int(complete.Timings.PredictedN),
-                        EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+                        PromptEvalCount:    p.PromptN,
+                        PromptEvalDuration: parseDurationMs(p.PromptMS),
+                        EvalCount:          p.PredictedN,
+                        EvalDuration:       parseDurationMs(p.PredictedMS),
                     })

                     return nil
                 }
-
-                var pred Prediction
-                if err := json.Unmarshal([]byte(evt), &pred); err != nil {
-                    return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-                }
-
-                genCtx += pred.Content
-                fn(api.GenerateResponse{Response: pred.Content})
             }
         }
     }
@@ -545,34 +528,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
     return nil
 }

-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-    pEncode, err := llm.Encode(ctx, prompt)
-    if err != nil {
-        return "", fmt.Errorf("encoding prompt context: %w", err)
-    }
-
-    tokens := append(pCtx, pEncode...)
-    if llm.NumKeep < 0 {
-        llm.NumKeep = len(tokens)
-    }
-
-    // min(llm.NumCtx - 4, llm.NumKeep)
-    if llm.NumCtx-4 < llm.NumKeep {
-        llm.NumKeep = llm.NumCtx - 4
-    }
-
-    if len(tokens) >= llm.NumCtx {
-        // truncate input
-        numLeft := (llm.NumCtx - llm.NumKeep) / 2
-        truncated := tokens[:llm.NumKeep]
-        erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-        truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-        tokens = truncated
-
-        log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-    }
-
-    return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
     Content string `json:"content"`
 }
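For reference, the truncation removed here kept the first NumKeep tokens plus roughly the most recent half-window of the rest. A worked example with assumed numbers (not from the source):

    // Assumed values: num_ctx=2048, num_keep=48, 5000 input tokens.
    numCtx, numKeep, numTokens := 2048, 48, 5000
    numLeft := (numCtx - numKeep) / 2                             // 1000
    erasedBlocks := (numTokens - numKeep - numLeft - 1) / numLeft // 3
    tailStart := numKeep + erasedBlocks*numLeft                   // 3048
    kept := numKeep + (numTokens - tailStart)                     // 48 + 1952 = 2000 tokens kept, fits in num_ctx

With this change, Predict no longer does that bookkeeping; it sends the full concatenated conversation in PredictRequest.Prompt along with NKeep, as shown in the hunks above.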
server/routes.go (view file @ 7b5aefb4)
@@ -117,12 +117,13 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses
         if err != nil {
             return err
         }

         tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
         if err != nil {
             return err
         }

-        opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+        opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
+
         llmModel.SetOptions(opts)
     }
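The server-side tweak drops the extra +1, so NumKeep is now exactly the number of tokens contributed by the system prompt. With assumed token counts for illustration:

    // Hypothetical counts: the prompt encodes to 40 tokens with the system message, 32 without it.
    withSystem, withoutSystem := 40, 32
    numKeep := withSystem - withoutSystem // now 8; previously this would have been 9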