OpenDAS / ollama · Commits

Commit 7b5aefb4 (unverified)
Authored Sep 05, 2023 by Michael Yang; committed by GitHub, Sep 05, 2023

Merge pull request #462 from jmorganca/mxyng/rm-marshal-prompt

remove marshalPrompt which is no longer needed

Parents: 7fa6e516, 7efbc843
Showing 2 changed files with 38 additions and 82 deletions:

    llm/ggml_llama.go   +36 −81
    server/routes.go     +2  −1
llm/ggml_llama.go
@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
         runner.Path,
         append(params, "--port", strconv.Itoa(port))...,
     )
-    var stderr bytes.Buffer
-    cmd.Stderr = &stderr
+    cmd.Stdout = os.Stderr
+    cmd.Stderr = os.Stderr
     llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
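The practical effect of the hunk above is that the runner's output is no longer captured in an in-memory bytes.Buffer; both of its streams are forwarded to this process's stderr as they are produced. A minimal standalone sketch of the same wiring (the echo command stands in for the llama runner binary):

    package main

    import (
        "os"
        "os/exec"
    )

    func main() {
        // Forward the child's stdout and stderr to our own stderr instead of
        // buffering them, mirroring the change in newLlama above.
        cmd := exec.Command("echo", "runner output")
        cmd.Stdout = os.Stderr
        cmd.Stderr = os.Stderr
        _ = cmd.Run()
    }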
@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
     llm.Options = opts
 }
 
-type Prediction struct {
-    Content string `json:"content"`
-    Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
     FrequencyPenalty float64 `json:"frequency_penalty"`
     IgnoreEOS        bool    `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }
 
 type Timings struct {
-    PredictedMS         float64 `json:"predicted_ms"`
-    PredictedN          int     `json:"predicted_n"`
-    PredictedPerSecond  float64 `json:"predicted_per_second"`
-    PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-    PromptMS            float64 `json:"prompt_ms"`
-    PromptN             int     `json:"prompt_n"`
-    PromptPerSecond     float64 `json:"prompt_per_second"`
-    PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+    PredictedN  int     `json:"predicted_n"`
+    PredictedMS float64 `json:"predicted_ms"`
+    PromptN     int     `json:"prompt_n"`
+    PromptMS    float64 `json:"prompt_ms"`
 }
 
-type PredictComplete struct {
+type Prediction struct {
     Content string `json:"content"`
-    GenerationSettings GenerationSettings `json:"generation_settings"`
     Model   string `json:"model"`
     Prompt  string `json:"prompt"`
     Stop    bool   `json:"stop"`
-    StoppedEOS   bool    `json:"stopped_eos"`
-    StoppedLimit bool    `json:"stopped_limit"`
-    StoppedWord  bool    `json:"stopped_word"`
-    StoppingWord string  `json:"stopping_word"`
-    Timings      Timings `json:"timings"`
-    TokensCached    int  `json:"tokens_cached"`
-    TokensEvaluated int  `json:"tokens_evaluated"`
-    TokensPredicted int  `json:"tokens_predicted"`
-    Truncated       bool `json:"truncated"`
+
+    Timings `json:"timings"`
 }
 
 type PredictRequest struct {
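Note that in the new Prediction struct, Timings is an embedded field rather than a named one, so its members are promoted and the handler later in this diff can write p.PredictedN or p.PromptMS directly instead of p.Timings.PredictedN. A compressed sketch of the same pattern (types trimmed to two fields for brevity):

    type Timings struct {
        PredictedN int     `json:"predicted_n"`
        PromptMS   float64 `json:"prompt_ms"`
    }

    type Prediction struct {
        Content string `json:"content"`
        Timings `json:"timings"`
    }

    // p.PredictedN and p.PromptMS resolve to the embedded Timings fields.
    func evalStats(p Prediction) (int, float64) {
        return p.PredictedN, p.PromptMS
    }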
@@ -437,15 +420,19 @@ type PredictRequest struct {
     Stop             []string `json:"stop,omitempty"`
 }
 
-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-    // we need to find the trimmed prompt context before predicting so that we can return it to the client
-    trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+    prevConvo, err := llm.Decode(ctx, prevContext)
     if err != nil {
-        return fmt.Errorf("marshaling prompt: %v", err)
+        return err
     }
 
+    var nextContext strings.Builder
+    nextContext.WriteString(prevConvo)
+    nextContext.WriteString(prompt)
+
     endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
     predReq := PredictRequest{
-        Prompt:   trimmedPrompt,
+        Prompt:   nextContext.String(),
         Stream:   true,
         NPredict: llm.NumPredict,
         NKeep:    llm.NumKeep,
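Under the new signature the caller, not marshalPrompt, is responsible for threading conversation state: it passes the Context tokens returned by the previous call back in as prevContext. A sketch of that calling pattern, assuming llm is the *llama value built by newLlama and ctx is a context.Context (the prompts are placeholders):

    var prevContext []int
    for _, prompt := range []string{"Hello.", "And a follow-up question."} {
        err := llm.Predict(ctx, prevContext, prompt, func(r api.GenerateResponse) {
            fmt.Print(r.Response)
            if r.Done {
                // Reuse the encoded conversation on the next turn.
                prevContext = r.Context
            }
        })
        if err != nil {
            return err
        }
    }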
@@ -491,7 +478,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
     }
 
     scanner := bufio.NewScanner(resp.Body)
-    genCtx := trimmedPrompt // start with the trimmed prompt
     for scanner.Scan() {
         select {
         case <-ctx.Done():
@@ -506,34 +492,31 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
         // Read data from the server-side event stream
         if strings.HasPrefix(line, "data: ") {
             evt := line[6:]
-            var complete PredictComplete
-            if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-                return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+            var p Prediction
+            if err := json.Unmarshal([]byte(evt), &p); err != nil {
+                return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
             }
 
-            if complete.Timings.PredictedMS > 0 {
-                genCtx += complete.Content
-                embd, err := llm.Encode(ctx, genCtx)
+            fn(api.GenerateResponse{Response: p.Content})
+            nextContext.WriteString(p.Content)
+
+            if p.Stop {
+                embd, err := llm.Encode(ctx, nextContext.String())
                 if err != nil {
                     return fmt.Errorf("encoding context: %v", err)
                 }
 
                 fn(api.GenerateResponse{
                     Done:               true,
                     Context:            embd,
-                    PromptEvalCount:    int(complete.Timings.PromptN),
-                    PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-                    EvalCount:          int(complete.Timings.PredictedN),
-                    EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+                    PromptEvalCount:    p.PromptN,
+                    PromptEvalDuration: parseDurationMs(p.PromptMS),
+                    EvalCount:          p.PredictedN,
+                    EvalDuration:       parseDurationMs(p.PredictedMS),
                 })
-                return nil
-            }
 
-            var pred Prediction
-            if err := json.Unmarshal([]byte(evt), &pred); err != nil {
-                return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+                return nil
             }
-
-            genCtx += pred.Content
-            fn(api.GenerateResponse{Response: pred.Content})
         }
     }
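Putting the pieces of this function together, the streaming loop is a plain server-sent-events reader over the runner's /completion endpoint. A self-contained sketch of that flow, assuming a llama.cpp-style server is already listening on the given port (the prediction type here is cut down to the two fields the loop needs):

    package main

    import (
        "bufio"
        "bytes"
        "encoding/json"
        "fmt"
        "net/http"
        "strings"
    )

    type prediction struct {
        Content string `json:"content"`
        Stop    bool   `json:"stop"`
    }

    func stream(port int, prompt string, fn func(string)) error {
        body, _ := json.Marshal(map[string]any{"prompt": prompt, "stream": true})
        url := fmt.Sprintf("http://127.0.0.1:%d/completion", port)
        resp, err := http.Post(url, "application/json", bytes.NewReader(body))
        if err != nil {
            return err
        }
        defer resp.Body.Close()

        scanner := bufio.NewScanner(resp.Body)
        for scanner.Scan() {
            line := scanner.Text()
            // Each event arrives as a "data: {...}" line; skip anything else.
            if !strings.HasPrefix(line, "data: ") {
                continue
            }
            var p prediction
            if err := json.Unmarshal([]byte(line[6:]), &p); err != nil {
                return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
            }
            fn(p.Content)
            if p.Stop {
                return nil
            }
        }
        return scanner.Err()
    }

    func main() {
        // Hypothetical usage; 8080 stands in for whatever port newLlama chose.
        _ = stream(8080, "why is the sky blue?", func(s string) { fmt.Print(s) })
    }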
@@ -545,34 +528,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
     return nil
 }
 
-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-    pEncode, err := llm.Encode(ctx, prompt)
-    if err != nil {
-        return "", fmt.Errorf("encoding prompt context: %w", err)
-    }
-
-    tokens := append(pCtx, pEncode...)
-    if llm.NumKeep < 0 {
-        llm.NumKeep = len(tokens)
-    }
-
-    // min(llm.NumCtx - 4, llm.NumKeep)
-    if llm.NumCtx-4 < llm.NumKeep {
-        llm.NumKeep = llm.NumCtx - 4
-    }
-
-    if len(tokens) >= llm.NumCtx {
-        // truncate input
-        numLeft := (llm.NumCtx - llm.NumKeep) / 2
-        truncated := tokens[:llm.NumKeep]
-        erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-        truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-        tokens = truncated
-
-        log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-    }
-
-    return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
     Content string `json:"content"`
 }
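For reference, the truncation the removed helper performed worked in whole blocks of numLeft tokens. With, say, NumCtx = 2048, NumKeep = 48 and a 3000-token input (illustrative numbers, not values from the diff), numLeft = (2048 - 48) / 2 = 1000 and erasedBlocks = (3000 - 48 - 1000 - 1) / 1000 = 1, so the prompt sent to the runner became tokens[:48] followed by tokens[1048:], 2000 tokens in all. After this commit that client-side trimming is gone; the NKeep value in PredictRequest is still forwarded to the runner, which presumably handles any context trimming itself.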
server/routes.go
@@ -117,12 +117,13 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses
     if err != nil {
         return err
     }
 
     tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
     if err != nil {
         return err
     }
 
-    opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+    opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
+
     llmModel.SetOptions(opts)
 }
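The NumKeep computed here is the size of the system-prompt prefix: the same prompt is encoded once with and once without the system prompt, and the difference in token counts is what Predict later forwards to the runner as NKeep. A sketch of that calculation under those assumptions (the interface exists only to keep the example self-contained):

    type encoder interface {
        Encode(ctx context.Context, prompt string) ([]int, error)
    }

    func systemPromptTokens(ctx context.Context, m encoder, promptWithSystem, promptNoSystem string) (int, error) {
        withSystem, err := m.Encode(ctx, promptWithSystem)
        if err != nil {
            return 0, err
        }
        noSystem, err := m.Encode(ctx, promptNoSystem)
        if err != nil {
            return 0, err
        }
        // Tokens that only appear when the system prompt is present.
        return len(withSystem) - len(noSystem), nil
    }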