OpenDAS / ollama · Commits · 08b0e04f

Commit 08b0e04f (unverified), authored Oct 17, 2023 by Michael Yang, committed via GitHub on Oct 17, 2023.

Merge pull request #813 from jmorganca/mxyng/llama

refactor llm/llama.go

Parents: f3648fd2, b36b0b71
Changes: 1 changed file, with 38 additions and 90 deletions.

llm/llama.go  (+38, -90)
...
@@ -442,68 +442,18 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type GenerationSettings struct {
-	FrequencyPenalty float64       `json:"frequency_penalty"`
-	IgnoreEOS        bool          `json:"ignore_eos"`
-	LogitBias        []interface{} `json:"logit_bias"`
-	Mirostat         int           `json:"mirostat"`
-	MirostatEta      float64       `json:"mirostat_eta"`
-	MirostatTau      float64       `json:"mirostat_tau"`
-	Model            string        `json:"model"`
-	NCtx             int           `json:"n_ctx"`
-	NKeep            int           `json:"n_keep"`
-	NPredict         int           `json:"n_predict"`
-	NProbs           int           `json:"n_probs"`
-	PenalizeNl       bool          `json:"penalize_nl"`
-	PresencePenalty  float64       `json:"presence_penalty"`
-	RepeatLastN      int           `json:"repeat_last_n"`
-	RepeatPenalty    float64       `json:"repeat_penalty"`
-	Seed             uint32        `json:"seed"`
-	Stop             []string      `json:"stop"`
-	Stream           bool          `json:"stream"`
-	Temp             float64       `json:"temp"`
-	TfsZ             float64       `json:"tfs_z"`
-	TopK             int           `json:"top_k"`
-	TopP             float64       `json:"top_p"`
-	TypicalP         float64       `json:"typical_p"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type Prediction struct {
-	Content string `json:"content"`
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Stop    bool   `json:"stop"`
-
-	Timings `json:"timings"`
-}
-
-type PredictRequest struct {
-	Prompt           string   `json:"prompt"`
-	Stream           bool     `json:"stream"`
-	NPredict         int      `json:"n_predict"`
-	NKeep            int      `json:"n_keep"`
-	Temperature      float32  `json:"temperature"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	TfsZ             float32  `json:"tfs_z"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	PenalizeNl       bool     `json:"penalize_nl"`
-	Seed             int      `json:"seed"`
-	Stop             []string `json:"stop,omitempty"`
-}
+type prediction struct {
+	Content string `json:"content"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Stop    bool   `json:"stop"`
+
+	Timings struct {
+		PredictedN  int     `json:"predicted_n"`
+		PredictedMS float64 `json:"predicted_ms"`
+		PromptN     int     `json:"prompt_n"`
+		PromptMS    float64 `json:"prompt_ms"`
+	}
+}
 
 const maxBufferSize = 512 * format.KiloByte
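
Note: the new prediction type leaves the nested Timings field untagged. That still works because encoding/json matches JSON keys to exported field names case-insensitively, so the server's "timings" object lands in p.Timings. A minimal, self-contained sketch; the payload below is illustrative, not taken from this commit:

package main

import (
	"encoding/json"
	"fmt"
)

// prediction mirrors the type added in this diff.
type prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}

func main() {
	// Illustrative server event payload; the values are made up.
	evt := []byte(`{"content":"hi","stop":false,"timings":{"predicted_n":2,"predicted_ms":10.5}}`)

	var p prediction
	if err := json.Unmarshal(evt, &p); err != nil {
		panic(err)
	}

	// The untagged Timings field matches the "timings" key
	// case-insensitively, so the nested counters come through.
	fmt.Println(p.Content, p.Timings.PredictedN, p.Timings.PredictedMS)
}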
...
@@ -518,27 +468,26 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	nextContext.WriteString(prevConvo)
 	nextContext.WriteString(prompt)
 
-	predReq := PredictRequest{
-		Prompt:           nextContext.String(),
-		Stream:           true,
-		NPredict:         llm.NumPredict,
-		NKeep:            llm.NumKeep,
-		Temperature:      llm.Temperature,
-		TopK:             llm.TopK,
-		TopP:             llm.TopP,
-		TfsZ:             llm.TFSZ,
-		TypicalP:         llm.TypicalP,
-		RepeatLastN:      llm.RepeatLastN,
-		RepeatPenalty:    llm.RepeatPenalty,
-		PresencePenalty:  llm.PresencePenalty,
-		FrequencyPenalty: llm.FrequencyPenalty,
-		Mirostat:         llm.Mirostat,
-		MirostatTau:      llm.MirostatTau,
-		MirostatEta:      llm.MirostatEta,
-		PenalizeNl:       llm.PenalizeNewline,
-		Seed:             llm.Seed,
-		Stop:             llm.Stop,
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
+	request := map[string]any{
+		"prompt":            nextContext.String(),
+		"stream":            true,
+		"n_predict":         llm.NumPredict,
+		"n_keep":            llm.NumKeep,
+		"temperature":       llm.Temperature,
+		"top_k":             llm.TopK,
+		"top_p":             llm.TopP,
+		"tfs_z":             llm.TFSZ,
+		"typical_p":         llm.TypicalP,
+		"repeat_last_n":     llm.RepeatLastN,
+		"repeat_penalty":    llm.RepeatPenalty,
+		"presence_penalty":  llm.PresencePenalty,
+		"frequency_penalty": llm.FrequencyPenalty,
+		"mirostat":          llm.Mirostat,
+		"mirostat_tau":      llm.MirostatTau,
+		"mirostat_eta":      llm.MirostatEta,
+		"penalize_nl":       llm.PenalizeNewline,
+		"seed":              llm.Seed,
+		"stop":              llm.Stop,
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
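
Note: the typed PredictRequest is replaced here by a plain map[string]any carrying the same llama.cpp option keys. One observable difference: encoding/json writes map keys in sorted order, while struct fields encode in declaration order; the /completion endpoint accepts either. A small sketch using a few of the keys from the diff, with made-up values:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

func main() {
	// A subset of the request keys from the diff; values are made up.
	request := map[string]any{
		"prompt":    "hello",
		"stream":    true,
		"n_predict": 128,
	}

	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(request); err != nil {
		panic(err)
	}

	// encoding/json emits map keys in sorted order:
	fmt.Print(buf.String()) // {"n_predict":128,"prompt":"hello","stream":true}
}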
...
@@ -546,10 +495,11 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
 
-	if err := enc.Encode(predReq); err != nil {
+	if err := enc.Encode(request); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 	if err != nil {
 		return fmt.Errorf("error creating POST request: %v", err)
...
@@ -581,16 +531,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// This handles the request cancellation
 			return ctx.Err()
 		default:
-			line := scanner.Text()
-			if line == "" {
+			line := scanner.Bytes()
+			if len(line) == 0 {
 				continue
 			}
 
-			// Read data from the server-side event stream
-			if strings.HasPrefix(line, "data: ") {
-				evt := line[6:]
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+			if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
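
Note: bytes.CutPrefix (added in Go 1.20) folds the old strings.HasPrefix check and the manual line[6:] slice into one call, and it operates on []byte directly, so the scanner output can go straight to json.Unmarshal without a string round-trip. For example:

package main

import (
	"bytes"
	"fmt"
)

func main() {
	line := []byte(`data: {"content":"hi"}`)

	// One call replaces the prefix check plus the manual slice;
	// evt holds the bytes after "data: " when ok is true.
	if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
		fmt.Printf("%s\n", evt) // {"content":"hi"}
	}
}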
...
@@ -608,10 +556,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 				fn(api.GenerateResponse{
 					Done:    true,
 					Context: embd,
-					PromptEvalCount:    p.PromptN,
-					PromptEvalDuration: parseDurationMs(p.PromptMS),
-					EvalCount:          p.PredictedN,
-					EvalDuration:       parseDurationMs(p.PredictedMS),
+					PromptEvalCount:    p.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+					EvalCount:          p.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 				})
 
 				return nil
...
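
Note: parseDurationMs is defined elsewhere in llm/llama.go and is not shown in this diff. A plausible sketch, assuming the helper simply scales llama.cpp's millisecond floats into a time.Duration:

package main

import (
	"fmt"
	"time"
)

// Hypothetical sketch of parseDurationMs; the real definition is
// outside this diff and may differ.
func parseDurationMs(ms float64) time.Duration {
	return time.Duration(ms * float64(time.Millisecond))
}

func main() {
	fmt.Println(parseDurationMs(1234.5)) // 1.2345s
}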