GitLab — OpenDAS / ollama, commit db77dfe0 (unverified)

Merge pull request #102 from jmorganca/session-id

Session

Authored by Michael Yang on Jul 27, 2023; committed by GitHub on Jul 27, 2023. Parents: 36ad90e8, 688661ab.

This merge gives the server a notion of a session: a loaded model stays resident between requests under a session ID and is unloaded after a configurable session duration, and the CLI threads the session ID (together with the token context) through consecutive prompts.
Showing 5 changed files with 343 additions and 212 deletions (+343 -212):

    api/types.go      +58   -2
    cmd/cmd.go        +16   -8
    llama/llama.go    +187  -86
    llama/utils.go    +9    -98
    server/routes.go  +73   -18
api/types.go

```diff
 package api
 
 import (
+	"encoding/json"
 	"fmt"
+	"math"
 	"os"
 	"runtime"
 	"time"
 )
@@ -28,6 +30,9 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
+	SessionID       int64    `json:"session_id"`
+	SessionDuration Duration `json:"session_duration,omitempty"`
+
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
 	Context []int  `json:"context,omitempty"`
@@ -81,6 +86,9 @@ type ListResponseModel struct {
 }
 
 type GenerateResponse struct {
+	SessionID        int64     `json:"session_id"`
+	SessionExpiresAt time.Time `json:"session_expires_at"`
+
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
@@ -89,6 +97,9 @@ type GenerateResponse struct {
 	Context       []int         `json:"context,omitempty"`
 	TotalDuration time.Duration `json:"total_duration,omitempty"`
 	LoadDuration  time.Duration `json:"load_duration,omitempty"`
+
+	SampleCount    int           `json:"sample_count,omitempty"`
+	SampleDuration time.Duration `json:"sample_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
 	EvalCount          int           `json:"eval_count,omitempty"`
@@ -100,6 +111,19 @@ func (r *GenerateResponse) Summary() {
 		fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
 	}
 
 	if r.LoadDuration > 0 {
 		fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
 	}
 
+	if r.SampleCount > 0 {
+		fmt.Fprintf(os.Stderr, "sample count: %d token(s)\n", r.SampleCount)
+	}
+
+	if r.SampleDuration > 0 {
+		fmt.Fprintf(os.Stderr, "sample duration: %s\n", r.SampleDuration)
+		fmt.Fprintf(os.Stderr, "sample rate: %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds())
+	}
+
 	if r.PromptEvalCount > 0 {
 		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
 	}
@@ -127,6 +151,7 @@ type Options struct {
 	// Model options
 	NumCtx   int `json:"num_ctx,omitempty"`
+	NumKeep  int `json:"num_keep,omitempty"`
 	NumBatch int `json:"num_batch,omitempty"`
 	NumGPU   int `json:"num_gpu,omitempty"`
 	MainGPU  int `json:"main_gpu,omitempty"`
@@ -151,6 +176,7 @@ type Options struct {
 	Mirostat        int     `json:"mirostat,omitempty"`
 	MirostatTau     float32 `json:"mirostat_tau,omitempty"`
 	MirostatEta     float32 `json:"mirostat_eta,omitempty"`
+	PenalizeNewline bool    `json:"penalize_newline,omitempty"`
 	NumThread       int     `json:"num_thread,omitempty"`
 }
@@ -162,14 +188,14 @@ func DefaultOptions() Options {
 		UseNUMA: false,
 
 		NumCtx:   2048,
-		NumBatch: 512,
+		NumBatch: 1024,
 		NumGPU:   1,
 		LowVRAM:  false,
 		F16KV:    true,
 		UseMMap:  true,
 		UseMLock: false,
 
-		RepeatLastN:      512,
+		RepeatLastN:      64,
 		RepeatPenalty:    1.1,
 		FrequencyPenalty: 0.0,
 		PresencePenalty:  0.0,
@@ -181,7 +207,37 @@ func DefaultOptions() Options {
 		Mirostat:        0,
 		MirostatTau:     5.0,
 		MirostatEta:     0.1,
+		PenalizeNewline: true,
 
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+type Duration struct {
+	time.Duration
+}
+
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+	var v any
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+
+	d.Duration = 5 * time.Minute
+
+	switch t := v.(type) {
+	case float64:
+		if t < 0 {
+			t = math.MaxFloat64
+		}
+
+		d.Duration = time.Duration(t)
+	case string:
+		d.Duration, err = time.ParseDuration(t)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
```
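The Duration wrapper above is what lets `session_duration` accept either a JSON number or a Go duration string. A minimal standalone sketch of that behavior, mirroring the `UnmarshalJSON` in the hunk (the type is copied here only so the example runs outside the repo):

```go
package main

import (
	"encoding/json"
	"fmt"
	"math"
	"time"
)

// Duration mirrors api.Duration from the hunk above so this sketch is
// self-contained and runnable outside the repo.
type Duration struct {
	time.Duration
}

func (d *Duration) UnmarshalJSON(b []byte) (err error) {
	var v any
	if err := json.Unmarshal(b, &v); err != nil {
		return err
	}

	// Fallback used when the JSON value is neither a number nor a string
	// (for example null).
	d.Duration = 5 * time.Minute

	switch t := v.(type) {
	case float64:
		if t < 0 {
			t = math.MaxFloat64 // negative means "never expire"
		}
		d.Duration = time.Duration(t) // bare numbers are nanoseconds
	case string:
		d.Duration, err = time.ParseDuration(t) // e.g. "10m", "1h30m"
		if err != nil {
			return err
		}
	}

	return nil
}

func main() {
	for _, payload := range []string{
		`{"session_duration": "10m"}`,       // Go duration string
		`{"session_duration": 60000000000}`, // nanoseconds -> 1m0s
		`{"session_duration": null}`,        // falls back to the 5m default
	} {
		var req struct {
			SessionDuration Duration `json:"session_duration,omitempty"`
		}
		if err := json.Unmarshal([]byte(payload), &req); err != nil {
			panic(err)
		}
		fmt.Println(payload, "->", req.SessionDuration.Duration)
	}
}
```

Note that the five-minute fallback applies only when `UnmarshalJSON` actually runs (for example on `null`); if the field is omitted entirely, the zero value is kept.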
cmd/cmd.go

```diff
@@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
-var generateContextKey struct{}
+type generateContextKey string
 
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
@@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		var latest api.GenerateResponse
 
-		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
 		if !ok {
 			generateContext = []int{}
 		}
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
-		fn := func(resp api.GenerateResponse) error {
+		generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
+		if !ok {
+			generateSession = 0
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
+		fn := func(response api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
 			}
 
-			latest = resp
-			fmt.Print(resp.Response)
-
-			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
+			latest = response
+			fmt.Print(response.Response)
 
 			return nil
 		}
@@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		if verbose {
 			latest.Summary()
 		}
+
+		ctx := cmd.Context()
+		ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+		ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
+		cmd.SetContext(ctx)
 	}
 
 	return nil
```
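The switch from `var generateContextKey struct{}` to `type generateContextKey string` is what allows the CLI to stash two values, the token context and the new session ID, under distinct keys. A minimal sketch of the typed-context-key pattern (names are illustrative):

```go
package main

import (
	"context"
	"fmt"
)

// A named string type keeps context keys distinct from plain strings:
// ctxKey("session") and "session" are different types, so unrelated
// packages cannot collide with these entries by accident.
type ctxKey string

func main() {
	ctx := context.Background()
	ctx = context.WithValue(ctx, ctxKey("context"), []int{1, 2, 3})
	ctx = context.WithValue(ctx, ctxKey("session"), int64(42))

	// Retrieval mirrors the CLI: a comma-ok type assertion with a
	// fallback when the key is missing or holds the wrong type.
	tokens, ok := ctx.Value(ctxKey("context")).([]int)
	if !ok {
		tokens = []int{}
	}

	session, ok := ctx.Value(ctxKey("session")).(int64)
	if !ok {
		session = 0
	}

	fmt.Println(tokens, session) // [1 2 3] 42
}
```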
llama/llama.go

```diff
 package llama
 
 /*
-#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS
-#cgo CXXFLAGS: -std=c++11
+#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
+#cgo CXXFLAGS: -std=gnu++11
 #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
 #include <stdlib.h>
```
```diff
@@ -21,6 +21,7 @@ struct llama_sample_options
 	int mirostat;
 	float mirostat_tau;
 	float mirostat_eta;
+	bool penalize_newline;
 };
 
 llama_token llama_sample(
@@ -37,6 +38,8 @@ llama_token llama_sample(
 		false,
 	};
 
+	struct llama_token_data newline = candidates_p.data[llama_token_nl()];
+
 	llama_sample_repetition_penalty(
 		ctx, &candidates_p,
 		last_tokens, n_last_tokens,
@@ -47,6 +50,10 @@ llama_token llama_sample(
 		last_tokens, n_last_tokens,
 		opts->frequency_penalty, opts->presence_penalty);
 
+	if (!opts->penalize_newline) {
+		candidates_p.data[llama_token_nl()] = newline;
+	}
+
 	if (opts->temperature <= 0) {
 		return llama_sample_token_greedy(ctx, &candidates_p);
 	}
```
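The C changes above exempt the newline token from repetition penalties by snapshotting its candidate entry before the penalty passes and restoring it when `penalize_newline` is false. A loose Go rendering of the idea; the token id and the penalty rule here are illustrative stand-ins for the llama.cpp internals, not the repo's C code:

```go
package main

import "fmt"

// Repetition penalties shrink the logits of recently seen tokens; if the
// newline token is in that window, a long run of prose makes it ever harder
// to emit a line break. Saving and restoring its logit exempts it.
const newlineToken = 13 // illustrative token id for "\n"

func applyRepeatPenalty(logits []float32, lastTokens []int, penalty float32, penalizeNewline bool) {
	saved := logits[newlineToken] // snapshot before penalties

	for _, t := range lastTokens {
		if logits[t] > 0 {
			logits[t] /= penalty
		} else {
			logits[t] *= penalty
		}
	}

	if !penalizeNewline {
		logits[newlineToken] = saved // undo the penalty for "\n" only
	}
}

func main() {
	logits := []float32{0: 1.0, newlineToken: 2.0, 20: 0.5}
	applyRepeatPenalty(logits, []int{newlineToken, 20}, 1.1, false)
	fmt.Println(logits[newlineToken], logits[20]) // newline restored to 2; token 20 stays penalized
}
```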
```diff
@@ -82,29 +89,37 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"os"
 	"strings"
-	"time"
+	"sync"
 	"unicode/utf8"
 	"unsafe"
 
 	"github.com/jmorganca/ollama/api"
 )
 
-type llama struct {
+type LLM struct {
 	params *C.struct_llama_context_params
 	model  *C.struct_llama_model
 	ctx    *C.struct_llama_context
 
+	last   []C.llama_token
+	embd   []C.llama_token
+	cursor int
+
+	mu sync.Mutex
+	gc bool
+
 	api.Options
 }
 
-func New(model string, opts api.Options) (*llama, error) {
+func New(model string, opts api.Options) (*LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
-	llm := llama{Options: opts}
+	llm := LLM{Options: opts}
 
 	C.llama_backend_init(C.bool(llm.UseNUMA))
```
```diff
@@ -144,27 +159,118 @@ func New(model string, opts api.Options) (*llama, error) {
 	return &llm, nil
 }
 
-func (llm *llama) Close() {
+func (llm *LLM) Close() {
+	llm.gc = true
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
 	defer C.llama_free_model(llm.model)
 	defer C.llama_free(llm.ctx)
 
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
-	if input := llm.tokenize(prompt); input != nil {
-		embd := make([]C.llama_token, len(ctx))
-		for i := range ctx {
-			embd[i] = C.llama_token(ctx[i])
-		}
-
-		return llm.generate(append(embd, input...), fn)
-	}
-
-	return errors.New("llama: tokenize")
-}
+func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	C.llama_reset_timings(llm.ctx)
+
+	tokens := make([]C.llama_token, len(ctx))
+	for i := range tokens {
+		tokens[i] = C.llama_token(ctx[i])
+	}
+
+	if len(tokens) == 0 {
+		tokens = llm.tokenize(" ")
+	}
+
+	llm.marshalPrompt(tokens, prompt)
+
+	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
+
+	var b bytes.Buffer
+	for {
+		token, err := llm.next()
+		if llm.gc {
+			return nil
+		} else if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return err
+		}
+
+		b.WriteString(llm.detokenize(token))
+		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
+			fn(api.GenerateResponse{Response: b.String()})
+			b.Reset()
+		}
+	}
+
+	last := make([]int, 0, len(llm.last))
+	for _, i := range llm.last {
+		if i != 0 {
+			last = append(last, int(i))
+		}
+	}
+
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		Context:            last,
+		SampleCount:        int(timings.n_sample),
+		SampleDuration:     parseDurationMs(float64(timings.t_sample_ms)),
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       parseDurationMs(float64(timings.t_eval_ms)),
+	})
+
+	return nil
+}
+
+func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+	tokens := append(ctx, llm.tokenize(prompt)...)
+	if llm.NumKeep < 0 {
+		llm.NumKeep = len(tokens)
+	}
+
+	// min(llm.NumCtx - 4, llm.NumKeep)
+	if llm.NumCtx-4 < llm.NumKeep {
+		llm.NumKeep = llm.NumCtx - 4
+	}
+
+	if len(tokens) >= llm.NumCtx {
+		// truncate input
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := tokens[:llm.NumKeep]
+		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
+		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+		copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+
+		tokens = truncated
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
+	} else {
+		llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
+		llm.last = append(llm.last, tokens...)
+	}
+
+	var i int
+	for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+		// noop
+	}
+
+	llm.embd = tokens
+	if i == len(tokens) {
+		// evaluate at least one token to generate logits
+		i--
+	}
+
+	llm.cursor = i
+
+	log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
+
+	return tokens
+}
 
-func (llm *llama) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
```
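The prefix scan in `marshalPrompt` is what makes follow-up prompts cheap: tokens that match the previously evaluated sequence keep their KV-cache entries, and evaluation restarts at the cursor. A standalone sketch of that scan over plain ints:

```go
package main

import "fmt"

// commonPrefix sketches the prompt-cache logic in marshalPrompt above:
// tokens matching the previously evaluated sequence (llm.embd) keep their
// KV-cache entries, so evaluation restarts at the first mismatch.
func commonPrefix(cached, tokens []int) int {
	var i int
	for i = 0; i < len(cached) && i < len(tokens) && cached[i] == tokens[i]; i++ {
		// advance past the shared prefix
	}

	// If the new prompt is entirely cached, step back one token anyway:
	// at least one token must be evaluated to produce fresh logits.
	if i == len(tokens) {
		i--
	}

	return i
}

func main() {
	cached := []int{1, 2, 3, 4, 5}

	// A follow-up prompt that extends the conversation reuses the cache.
	fmt.Println(commonPrefix(cached, []int{1, 2, 3, 4, 5, 6, 7})) // 5: eval only {6, 7}

	// An identical prompt still evaluates one token for logits.
	fmt.Println(commonPrefix(cached, []int{1, 2, 3, 4, 5})) // 4

	// A different prompt starts over from the first mismatch.
	fmt.Println(commonPrefix(cached, []int{1, 9, 9})) // 1
}
```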
```diff
@@ -176,7 +282,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *llama) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) detokenize(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
@@ -185,98 +291,93 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
-	var opts C.struct_llama_sample_options
-	opts.repeat_penalty = C.float(llm.RepeatPenalty)
-	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
-	opts.presence_penalty = C.float(llm.PresencePenalty)
-	opts.temperature = C.float(llm.Temperature)
-	opts.top_k = C.int(llm.TopK)
-	opts.top_p = C.float(llm.TopP)
-	opts.tfs_z = C.float(llm.TFSZ)
-	opts.typical_p = C.float(llm.TypicalP)
-	opts.mirostat = C.int(llm.Mirostat)
-	opts.mirostat_tau = C.float(llm.MirostatTau)
-	opts.mirostat_eta = C.float(llm.MirostatEta)
-
-	output := deque[C.llama_token]{capacity: llm.NumCtx}
-
-	context := deque[int]{capacity: llm.NumCtx / 2}
-	for _, in := range input {
-		context.PushLeft(int(in))
-	}
-
-	var b bytes.Buffer
-	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
-			return errors.New("llama: eval")
-		}
-
-		token, err := llm.sample(output, &opts)
-		if errors.Is(err, io.EOF) {
-			break
-		} else if err != nil {
-			return err
-		}
-
-		b.WriteString(llm.detokenize(token))
-		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
-			// call the callback
-			fn(api.GenerateResponse{
-				Response: b.String(),
-			})
-
-			output.PushLeft(token)
-			context.PushLeft(int(token))
-			b.Reset()
-		}
-
-		input = []C.llama_token{token}
-	}
-
-	dur := func(ms float64) time.Duration {
-		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-		if err != nil {
-			panic(err)
-		}
-
-		return d
-	}
-
-	timings := C.llama_get_timings(llm.ctx)
-	fn(api.GenerateResponse{
-		Done:               true,
-		Context:            context.Data(),
-		PromptEvalCount:    int(timings.n_p_eval),
-		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
-		EvalCount:          int(timings.n_eval),
-		EvalDuration:       dur(float64(timings.t_eval_ms)),
-	})
-
-	return nil
-}
-
-func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
-	numVocab := int(C.llama_n_vocab(llm.ctx))
-	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
-
-	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
-	for i := 0; i < candidates.Cap(); i++ {
-		candidates.PushLeft(C.struct_llama_token_data{
-			id:    C.int(i),
-			logit: logits[i],
-			p:     0,
-		})
-	}
-
-	token := C.llama_sample(
-		llm.ctx,
-		unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()),
-		unsafe.SliceData(output.Data()), C.size_t(output.Len()),
-		opts)
-	if token != C.llama_token_eos() {
-		return token, nil
-	}
-
-	return 0, io.EOF
-}
+func (llm *LLM) next() (C.llama_token, error) {
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
+	if len(llm.embd) >= llm.NumCtx {
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := llm.embd[:llm.NumKeep]
+		truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...)
+
+		llm.embd = truncated
+		llm.cursor = llm.NumKeep
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor)
+	}
+
+	for {
+		if llm.gc {
+			return 0, io.EOF
+		}
+
+		if llm.cursor >= len(llm.embd) {
+			break
+		}
+
+		numEval := len(llm.embd) - llm.cursor
+		if numEval > llm.NumBatch {
+			numEval = llm.NumBatch
+		}
+
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 {
+			return 0, fmt.Errorf("llama_eval: %d", retval)
+		}
+
+		llm.cursor += numEval
+	}
+
+	var sampleOpts C.struct_llama_sample_options
+	sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty)
+	sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty)
+	sampleOpts.presence_penalty = C.float(llm.PresencePenalty)
+	sampleOpts.temperature = C.float(llm.Temperature)
+	sampleOpts.top_k = C.int(llm.TopK)
+	sampleOpts.top_p = C.float(llm.TopP)
+	sampleOpts.tfs_z = C.float(llm.TFSZ)
+	sampleOpts.typical_p = C.float(llm.TypicalP)
+	sampleOpts.mirostat = C.int(llm.Mirostat)
+	sampleOpts.mirostat_tau = C.float(llm.MirostatTau)
+	sampleOpts.mirostat_eta = C.float(llm.MirostatEta)
+	sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline)
+
+	numVocab := C.llama_n_vocab(llm.ctx)
+	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
+
+	// TODO: logit bias
+	candidates := make([]C.llama_token_data, numVocab)
+	for i := range logits {
+		candidates[i] = C.llama_token_data{
+			id:    C.int(i),
+			logit: logits[i],
+			p:     0,
+		}
+	}
+
+	repeatLastN := llm.RepeatLastN
+	if len(llm.last) < repeatLastN {
+		repeatLastN = len(llm.last)
+	}
+
+	if llm.NumCtx < repeatLastN {
+		repeatLastN = llm.NumCtx
+	}
+
+	lastN := llm.last[len(llm.last)-repeatLastN:]
+
+	token := C.llama_sample(
+		llm.ctx,
+		unsafe.SliceData(candidates), C.size_t(len(candidates)),
+		unsafe.SliceData(lastN), C.size_t(len(lastN)),
+		&sampleOpts,
+	)
+
+	llm.last = append(llm.last, token)
+	llm.embd = append(llm.embd, token)
+
+	if token == C.llama_token_eos() {
+		return 0, io.EOF
+	}
+
+	return token, nil
+}
```
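The truncation arithmetic in `marshalPrompt` keeps the first `NumKeep` tokens (typically the system prompt), drops whole `numLeft`-sized blocks from the middle, and keeps the most recent remainder. A standalone sketch with small numbers:

```go
package main

import "fmt"

// truncate sketches the overflow handling in marshalPrompt above: when a
// prompt exceeds the context window, keep the first numKeep tokens, erase
// whole blocks of numLeft tokens from the middle, and keep the tail.
func truncate(tokens []int, numCtx, numKeep int) []int {
	if len(tokens) < numCtx {
		return tokens
	}

	numLeft := (numCtx - numKeep) / 2
	erasedBlocks := (len(tokens) - numKeep - numLeft - 1) / numLeft

	truncated := append([]int{}, tokens[:numKeep]...)
	return append(truncated, tokens[numKeep+erasedBlocks*numLeft:]...)
}

func main() {
	tokens := make([]int, 20)
	for i := range tokens {
		tokens[i] = i
	}

	// numCtx=8, numKeep=2: numLeft=3, erasedBlocks=(20-2-3-1)/3=4,
	// so the result keeps tokens[:2] plus tokens[14:].
	fmt.Println(truncate(tokens, 8, 2)) // [0 1 14 15 16 17 18 19]
}
```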
llama/utils.go

```diff
 package llama
 
-type node[T any] struct {
-	t    T
-	next *node[T]
-	prev *node[T]
-}
-
-type deque[T any] struct {
-	head     *node[T]
-	tail     *node[T]
-	size     int
-	capacity int
-}
-
-func (d *deque[T]) Empty() bool {
-	return d.size == 0
-}
-
-func (d *deque[T]) Len() int {
-	return d.size
-}
-
-func (d *deque[T]) Cap() int {
-	return d.capacity
-}
-
-func (d *deque[T]) Push(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.PopLeft()
-	}
-
-	n := node[T]{t: t}
-	if d.head != nil {
-		n.next = d.head
-		d.head.prev = &n
-		d.head = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) PushLeft(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.Pop()
-	}
-
-	n := node[T]{t: t}
-	if d.tail != nil {
-		n.prev = d.tail
-		d.tail.next = &n
-		d.tail = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) Pop() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	head := d.head
-	d.head = head.next
-	if d.head != nil {
-		d.head.prev = nil
-	} else {
-		d.tail = nil
-	}
-
-	d.size--
-	return &head.t
-}
-
-func (d *deque[T]) PopLeft() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	tail := d.tail
-	d.tail = tail.prev
-	if d.tail != nil {
-		d.tail.next = nil
-	} else {
-		d.head = nil
-	}
-
-	d.size--
-	return &tail.t
-}
-
-func (d *deque[T]) Data() (data []T) {
-	for n := d.head; n != nil; n = n.next {
-		data = append(data, n.t)
-	}
-
-	return data
-}
+import (
+	"fmt"
+	"time"
+)
+
+func parseDurationMs(ms float64) time.Duration {
+	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+	if err != nil {
+		panic(err)
+	}
+
+	return dur
+}
```
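`parseDurationMs` is now all that remains of utils.go: it converts llama.cpp's fractional-millisecond timings into a `time.Duration` by round-tripping through `ParseDuration`. A small demonstration, with a direct arithmetic alternative for comparison:

```go
package main

import (
	"fmt"
	"time"
)

// parseDurationMs mirrors llama/utils.go above: llama.cpp reports timings
// as fractional milliseconds, and ParseDuration handles values like
// "1234.567000ms" directly.
func parseDurationMs(ms float64) time.Duration {
	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
	if err != nil {
		panic(err)
	}

	return dur
}

func main() {
	fmt.Println(parseDurationMs(1234.567)) // 1.234567s

	// An equivalent without string formatting, shown for comparison; it
	// avoids the allocation but keeps only what fits in integer nanoseconds.
	fmt.Println(time.Duration(1234.567 * float64(time.Millisecond))) // 1.234567s
}
```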
server/routes.go

```diff
@@ -11,6 +11,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 	"time"
 
 	"dario.cat/mergo"
@@ -21,8 +22,21 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
+var activeSession struct {
+	mu sync.Mutex
+
+	id  int64
+	llm *llama.LLM
+
+	expireAt    time.Time
+	expireTimer *time.Timer
+}
+
 func GenerateHandler(c *gin.Context) {
-	start := time.Now()
+	activeSession.mu.Lock()
+	defer activeSession.mu.Unlock()
+
+	checkpointStart := time.Now()
 
 	var req api.GenerateRequest
 	if err := c.ShouldBindJSON(&req); err != nil {
```
```diff
@@ -36,44 +50,85 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	opts := api.DefaultOptions()
-	if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	prompt, err := model.Prompt(req)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	if req.SessionID == 0 || req.SessionID != activeSession.id {
+		if activeSession.llm != nil {
+			activeSession.llm.Close()
+			activeSession.llm = nil
+		}
+
+		opts := api.DefaultOptions()
+		if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		llm, err := llama.New(model.ModelPath, opts)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		activeSession.id = time.Now().UnixNano()
+		activeSession.llm = llm
+	}
+
+	sessionDuration := req.SessionDuration
+	sessionID := activeSession.id
+
+	activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+	if activeSession.expireTimer == nil {
+		activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
+			activeSession.mu.Lock()
+			defer activeSession.mu.Unlock()
+
+			if sessionID != activeSession.id {
+				return
+			}
+
+			if time.Now().Before(activeSession.expireAt) {
+				return
+			}
+
+			activeSession.llm.Close()
+			activeSession.llm = nil
+			activeSession.id = 0
+		})
+	}
+	activeSession.expireTimer.Reset(sessionDuration.Duration)
+
+	checkpointLoaded := time.Now()
 
-	llm, err := llama.New(model.ModelPath, opts)
+	prompt, err := model.Prompt(req)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	defer llm.Close()
 
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
 		fn := func(r api.GenerateResponse) {
+			activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+			activeSession.expireTimer.Reset(sessionDuration.Duration)
+
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
+			r.SessionID = activeSession.id
+			r.SessionExpiresAt = activeSession.expireAt.UTC()
+
 			if r.Done {
-				r.TotalDuration = time.Since(start)
+				r.TotalDuration = time.Since(checkpointStart)
+				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
 
 			ch <- r
 		}
 
-		if err := llm.Predict(req.Context, prompt, fn); err != nil {
+		if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
```
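The expiry logic above is a single-timer keep-alive: one `time.AfterFunc` timer created lazily, `Reset` on every request and on every streamed token, and a callback that re-checks the session id and deadline under the mutex before evicting, so a timer that fires just after a `Reset` cannot tear down a live session. A self-contained sketch of the same pattern (names are illustrative):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// session stands in for activeSession above; closed stands in for
// llm.Close() so the sketch is observable without a model.
type session struct {
	mu       sync.Mutex
	id       int64
	expireAt time.Time
	timer    *time.Timer
	closed   chan struct{}
}

// touch extends the session deadline, creating the timer on first use.
func (s *session) touch(d time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.expireAt = time.Now().Add(d)
	if s.timer == nil {
		id := s.id
		s.timer = time.AfterFunc(d, func() {
			s.mu.Lock()
			defer s.mu.Unlock()

			// Re-check under the lock: a newer session or a refreshed
			// deadline means this firing is stale.
			if id != s.id || time.Now().Before(s.expireAt) {
				return
			}

			close(s.closed)
		})
	}

	s.timer.Reset(d)
}

func main() {
	s := &session{id: 1, closed: make(chan struct{})}
	s.touch(50 * time.Millisecond)
	s.touch(100 * time.Millisecond) // a second request extends the deadline

	start := time.Now()
	<-s.closed
	fmt.Println("expired after", time.Since(start).Round(10*time.Millisecond))
}
```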
```diff
@@ -223,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
 		return
 	}
 
-	c.JSON(http.StatusOK, api.ListResponse{models})
+	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 
 func CopyModelHandler(c *gin.Context) {
```