OpenDAS / ollama · Commits · db77dfe0

Unverified commit db77dfe0, authored Jul 27, 2023 by Michael Yang; committed by GitHub, Jul 27, 2023

Merge pull request #102 from jmorganca/session-id

Session

Parents: 36ad90e8, 688661ab

Showing 5 changed files with 343 additions and 212 deletions (+343, -212):

	api/types.go      +58  -2
	cmd/cmd.go        +16  -8
	llama/llama.go    +187 -86
	llama/utils.go    +9   -98
	server/routes.go  +73  -18
api/types.go

 package api
 
 import (
+	"encoding/json"
 	"fmt"
+	"math"
 	"os"
 	"runtime"
 	"time"
 )

@@ -28,6 +30,9 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
+	SessionID       int64    `json:"session_id"`
+	SessionDuration Duration `json:"session_duration,omitempty"`
+
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
 	Context []int  `json:"context,omitempty"`

@@ -81,6 +86,9 @@ type ListResponseModel struct {
 }
 
 type GenerateResponse struct {
+	SessionID        int64     `json:"session_id"`
+	SessionExpiresAt time.Time `json:"session_expires_at"`
+
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`

@@ -89,6 +97,9 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`
 
 	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	SampleCount        int           `json:"sample_count,omitempty"`
+	SampleDuration     time.Duration `json:"sample_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
 	EvalCount          int           `json:"eval_count,omitempty"`
@@ -100,6 +111,19 @@ func (r *GenerateResponse) Summary() {
 		fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
 	}
 
+	if r.LoadDuration > 0 {
+		fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
+	}
+
+	if r.SampleCount > 0 {
+		fmt.Fprintf(os.Stderr, "sample count: %d token(s)\n", r.SampleCount)
+	}
+
+	if r.SampleDuration > 0 {
+		fmt.Fprintf(os.Stderr, "sample duration: %s\n", r.SampleDuration)
+		fmt.Fprintf(os.Stderr, "sample rate: %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds())
+	}
+
 	if r.PromptEvalCount > 0 {
 		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
 	}

@@ -127,6 +151,7 @@ type Options struct {
 	// Model options
 	NumCtx   int `json:"num_ctx,omitempty"`
+	NumKeep  int `json:"num_keep,omitempty"`
 	NumBatch int `json:"num_batch,omitempty"`
 	NumGPU   int `json:"num_gpu,omitempty"`
 	MainGPU  int `json:"main_gpu,omitempty"`

@@ -151,6 +176,7 @@ type Options struct {
 	Mirostat        int     `json:"mirostat,omitempty"`
 	MirostatTau     float32 `json:"mirostat_tau,omitempty"`
 	MirostatEta     float32 `json:"mirostat_eta,omitempty"`
+	PenalizeNewline bool    `json:"penalize_newline,omitempty"`
 	NumThread       int     `json:"num_thread,omitempty"`
 }

@@ -162,14 +188,14 @@ func DefaultOptions() Options {
 		UseNUMA: false,
 
 		NumCtx:   2048,
-		NumBatch: 512,
+		NumBatch: 1024,
 		NumGPU:   1,
 		LowVRAM:  false,
 		F16KV:    true,
 		UseMMap:  true,
 		UseMLock: false,
 
-		RepeatLastN:      512,
+		RepeatLastN:      64,
 		RepeatPenalty:    1.1,
 		FrequencyPenalty: 0.0,
 		PresencePenalty:  0.0,

@@ -181,7 +207,37 @@ func DefaultOptions() Options {
 		Mirostat:        0,
 		MirostatTau:     5.0,
 		MirostatEta:     0.1,
+		PenalizeNewline: true,
 
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+type Duration struct {
+	time.Duration
+}
+
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+	var v any
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+
+	d.Duration = 5 * time.Minute
+
+	switch t := v.(type) {
+	case float64:
+		if t < 0 {
+			t = math.MaxFloat64
+		}
+
+		d.Duration = time.Duration(t)
+	case string:
+		d.Duration, err = time.ParseDuration(t)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
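Note: the new Duration type accepts either a JSON number (taken as nanoseconds) or a Go duration string, treats a negative number as "effectively never expire", and falls back to the 5-minute default for other JSON types. A quick standalone check of the two common forms, assuming this commit's api package is importable:

	package main

	import (
		"encoding/json"
		"fmt"

		"github.com/jmorganca/ollama/api"
	)

	func main() {
		// A quoted string parses via time.ParseDuration; a bare number is
		// taken as nanoseconds. Both decode to the same ten minutes here.
		for _, s := range []string{`"10m"`, `600000000000`} {
			var d api.Duration
			if err := json.Unmarshal([]byte(s), &d); err != nil {
				panic(err)
			}
			fmt.Printf("%s => %v\n", s, d.Duration) // "10m" => 10m0s, 600000000000 => 10m0s
		}
	}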
cmd/cmd.go

@@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
-var generateContextKey struct{}
+type generateContextKey string
 
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {

@@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		var latest api.GenerateResponse
 
-		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
 		if !ok {
 			generateContext = []int{}
 		}
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
-		fn := func(resp api.GenerateResponse) error {
+		generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
+		if !ok {
+			generateSession = 0
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
+		fn := func(response api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
 			}
 
-			latest = resp
-			fmt.Print(resp.Response)
+			latest = response
+			fmt.Print(response.Response)
 
-			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
 			return nil
 		}

@@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		if verbose {
 			latest.Summary()
 		}
+
+		ctx := cmd.Context()
+		ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+		ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
+		cmd.SetContext(ctx)
 	}
 
 	return nil
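Note: switching generateContextKey from a struct value to a defined string type is what lets the CLI stash two values ("context" and "session") under one key type without colliding with string keys from other packages. A minimal sketch of the pattern:

	package main

	import (
		"context"
		"fmt"
	)

	// A defined key type: values of this type never compare equal to a plain
	// string key, so other packages cannot clobber these entries.
	type generateContextKey string

	func main() {
		ctx := context.Background()
		ctx = context.WithValue(ctx, generateContextKey("context"), []int{1, 2, 3})
		ctx = context.WithValue(ctx, generateContextKey("session"), int64(42))

		tokens, _ := ctx.Value(generateContextKey("context")).([]int)
		session, _ := ctx.Value(generateContextKey("session")).(int64)
		fmt.Println(tokens, session) // [1 2 3] 42
	}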
llama/llama.go

 package llama
 
 /*
-#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS
-#cgo CXXFLAGS: -std=c++11
+#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
+#cgo CXXFLAGS: -std=gnu++11
 #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
 #include <stdlib.h>

@@ -21,6 +21,7 @@ struct llama_sample_options
 	int mirostat;
 	float mirostat_tau;
 	float mirostat_eta;
+	bool penalize_newline;
 };
 
 llama_token llama_sample(

@@ -37,6 +38,8 @@ llama_token llama_sample(
 		false,
 	};
 
+	struct llama_token_data newline = candidates_p.data[llama_token_nl()];
+
 	llama_sample_repetition_penalty(
 		ctx, &candidates_p,
 		last_tokens, n_last_tokens,

@@ -47,6 +50,10 @@ llama_token llama_sample(
 		last_tokens, n_last_tokens,
 		opts->frequency_penalty, opts->presence_penalty);
 
+	if (!opts->penalize_newline) {
+		candidates_p.data[llama_token_nl()] = newline;
+	}
+
 	if (opts->temperature <= 0) {
 		return llama_sample_token_greedy(ctx, &candidates_p);
 	}
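Note: the C hunks above snapshot the newline token's candidate before the repetition/frequency/presence penalties run, then restore it when penalize_newline is off, so "\n" keeps its original score. A toy Go rendering of that save-and-restore idea (names here are illustrative, not from the diff):

	package main

	import "fmt"

	type tokenData struct {
		id    int
		logit float32
	}

	// halveAll stands in for the repetition penalties, which may lower "\n".
	func halveAll(candidates []tokenData) {
		for i := range candidates {
			candidates[i].logit /= 2
		}
	}

	func applyPenalties(candidates []tokenData, newlineID int, penalizeNewline bool) {
		saved := candidates[newlineID] // candidates are still indexed by token id here
		halveAll(candidates)
		if !penalizeNewline {
			candidates[newlineID] = saved // put "\n" back to its pre-penalty score
		}
	}

	func main() {
		c := []tokenData{{0, 1.0}, {1, 2.0}}
		applyPenalties(c, 1, false)
		fmt.Println(c) // [{0 0.5} {1 2}]
	}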
@@ -82,29 +89,37 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"os"
 	"strings"
-	"time"
+	"sync"
 	"unicode/utf8"
 	"unsafe"
 
 	"github.com/jmorganca/ollama/api"
 )
 
-type llama struct {
+type LLM struct {
 	params *C.struct_llama_context_params
 	model  *C.struct_llama_model
 	ctx    *C.struct_llama_context
 
+	last   []C.llama_token
+	embd   []C.llama_token
+	cursor int
+
+	mu sync.Mutex
+	gc bool
+
 	api.Options
 }
 
-func New(model string, opts api.Options) (*llama, error) {
+func New(model string, opts api.Options) (*LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
-	llm := llama{Options: opts}
+	llm := LLM{Options: opts}
 
 	C.llama_backend_init(C.bool(llm.UseNUMA))

@@ -144,27 +159,118 @@ func New(model string, opts api.Options) (*llama, error) {
 	return &llm, nil
 }
 
-func (llm *llama) Close() {
+func (llm *LLM) Close() {
+	llm.gc = true
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
 	defer C.llama_free_model(llm.model)
 	defer C.llama_free(llm.ctx)
 
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
-	if input := llm.tokenize(prompt); input != nil {
-		embd := make([]C.llama_token, len(ctx))
-		for i := range ctx {
-			embd[i] = C.llama_token(ctx[i])
-		}
-
-		return llm.generate(append(embd, input...), fn)
-	}
-
-	return errors.New("llama: tokenize")
-}
+func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	C.llama_reset_timings(llm.ctx)
+
+	tokens := make([]C.llama_token, len(ctx))
+	for i := range tokens {
+		tokens[i] = C.llama_token(ctx[i])
+	}
+
+	if len(tokens) == 0 {
+		tokens = llm.tokenize(" ")
+	}
+
+	llm.marshalPrompt(tokens, prompt)
+
+	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
+
+	var b bytes.Buffer
+	for {
+		token, err := llm.next()
+		if llm.gc {
+			return nil
+		} else if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return err
+		}
+
+		b.WriteString(llm.detokenize(token))
+		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
+			fn(api.GenerateResponse{Response: b.String()})
+			b.Reset()
+		}
+	}
+
+	last := make([]int, 0, len(llm.last))
+	for _, i := range llm.last {
+		if i != 0 {
+			last = append(last, int(i))
+		}
+	}
+
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		Context:            last,
+		SampleCount:        int(timings.n_sample),
+		SampleDuration:     parseDurationMs(float64(timings.t_sample_ms)),
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       parseDurationMs(float64(timings.t_eval_ms)),
+	})
+
+	return nil
+}
+
+func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+	tokens := append(ctx, llm.tokenize(prompt)...)
+	if llm.NumKeep < 0 {
+		llm.NumKeep = len(tokens)
+	}
+
+	// min(llm.NumCtx - 4, llm.NumKeep)
+	if llm.NumCtx-4 < llm.NumKeep {
+		llm.NumKeep = llm.NumCtx - 4
+	}
+
+	if len(tokens) >= llm.NumCtx {
+		// truncate input
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := tokens[:llm.NumKeep]
+		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
+		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+		copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+
+		tokens = truncated
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
+	} else {
+		llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
+		llm.last = append(llm.last, tokens...)
+	}
+
+	var i int
+	for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+		// noop
+	}
+
+	llm.embd = tokens
+	if i == len(tokens) {
+		// evaluate at least one token to generate logits
+		i--
+	}
+
+	llm.cursor = i
+
+	log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
+	return tokens
+}
 
-func (llm *llama) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))

@@ -176,7 +282,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *llama) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) detokenize(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))

@@ -185,98 +291,93 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
-	var opts C.struct_llama_sample_options
-	opts.repeat_penalty = C.float(llm.RepeatPenalty)
-	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
-	opts.presence_penalty = C.float(llm.PresencePenalty)
-	opts.temperature = C.float(llm.Temperature)
-	opts.top_k = C.int(llm.TopK)
-	opts.top_p = C.float(llm.TopP)
-	opts.tfs_z = C.float(llm.TFSZ)
-	opts.typical_p = C.float(llm.TypicalP)
-	opts.mirostat = C.int(llm.Mirostat)
-	opts.mirostat_tau = C.float(llm.MirostatTau)
-	opts.mirostat_eta = C.float(llm.MirostatEta)
-
-	output := deque[C.llama_token]{capacity: llm.NumCtx}
-
-	context := deque[int]{capacity: llm.NumCtx / 2}
-	for _, in := range input {
-		context.PushLeft(int(in))
-	}
-
-	var b bytes.Buffer
-	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
-			return errors.New("llama: eval")
-		}
-
-		token, err := llm.sample(output, &opts)
-		if errors.Is(err, io.EOF) {
-			break
-		} else if err != nil {
-			return err
-		}
-
-		b.WriteString(llm.detokenize(token))
-		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
-			// call the callback
-			fn(api.GenerateResponse{
-				Response: b.String(),
-			})
-
-			output.PushLeft(token)
-			context.PushLeft(int(token))
-			b.Reset()
-		}
-
-		input = []C.llama_token{token}
-	}
-
-	dur := func(ms float64) time.Duration {
-		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-		if err != nil {
-			panic(err)
-		}
-
-		return d
-	}
-
-	timings := C.llama_get_timings(llm.ctx)
-	fn(api.GenerateResponse{
-		Done:               true,
-		Context:            context.Data(),
-		PromptEvalCount:    int(timings.n_p_eval),
-		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
-		EvalCount:          int(timings.n_eval),
-		EvalDuration:       dur(float64(timings.t_eval_ms)),
-	})
-
-	return nil
-}
-
-func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
-	numVocab := int(C.llama_n_vocab(llm.ctx))
-	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
-
-	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
-	for i := 0; i < candidates.Cap(); i++ {
-		candidates.PushLeft(C.struct_llama_token_data{
-			id:    C.int(i),
-			logit: logits[i],
-			p:     0,
-		})
-	}
-
-	token := C.llama_sample(
-		llm.ctx,
-		unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()),
-		unsafe.SliceData(output.Data()), C.size_t(output.Len()),
-		opts)
-
-	if token != C.llama_token_eos() {
-		return token, nil
-	}
-
-	return 0, io.EOF
-}
+func (llm *LLM) next() (C.llama_token, error) {
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
+	if len(llm.embd) >= llm.NumCtx {
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := llm.embd[:llm.NumKeep]
+		truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...)
+
+		llm.embd = truncated
+		llm.cursor = llm.NumKeep
+
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor)
+	}
+
+	for {
+		if llm.gc {
+			return 0, io.EOF
+		}
+
+		if llm.cursor >= len(llm.embd) {
+			break
+		}
+
+		numEval := len(llm.embd) - llm.cursor
+		if numEval > llm.NumBatch {
+			numEval = llm.NumBatch
+		}
+
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 {
+			return 0, fmt.Errorf("llama_eval: %d", retval)
+		}
+
+		llm.cursor += numEval
+	}
+
+	var sampleOpts C.struct_llama_sample_options
+	sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty)
+	sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty)
+	sampleOpts.presence_penalty = C.float(llm.PresencePenalty)
+	sampleOpts.temperature = C.float(llm.Temperature)
+	sampleOpts.top_k = C.int(llm.TopK)
+	sampleOpts.top_p = C.float(llm.TopP)
+	sampleOpts.tfs_z = C.float(llm.TFSZ)
+	sampleOpts.typical_p = C.float(llm.TypicalP)
+	sampleOpts.mirostat = C.int(llm.Mirostat)
+	sampleOpts.mirostat_tau = C.float(llm.MirostatTau)
+	sampleOpts.mirostat_eta = C.float(llm.MirostatEta)
+	sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline)
+
+	numVocab := C.llama_n_vocab(llm.ctx)
+	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
+
+	// TODO: logit bias
+
+	candidates := make([]C.llama_token_data, numVocab)
+	for i := range logits {
+		candidates[i] = C.llama_token_data{
+			id:    C.int(i),
+			logit: logits[i],
+			p:     0,
+		}
+	}
+
+	repeatLastN := llm.RepeatLastN
+	if len(llm.last) < repeatLastN {
+		repeatLastN = len(llm.last)
+	}
+
+	if llm.NumCtx < repeatLastN {
+		repeatLastN = llm.NumCtx
+	}
+
+	lastN := llm.last[len(llm.last)-repeatLastN:]
+
+	token := C.llama_sample(
+		llm.ctx,
+		unsafe.SliceData(candidates), C.size_t(len(candidates)),
+		unsafe.SliceData(lastN), C.size_t(len(lastN)),
+		&sampleOpts,
+	)
+
+	llm.last = append(llm.last, token)
+	llm.embd = append(llm.embd, token)
+
+	if token == C.llama_token_eos() {
+		return 0, io.EOF
+	}
+
+	return token, nil
+}
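Note: the heart of the session reuse is the prefix comparison in marshalPrompt: the tokens already evaluated into the KV cache (llm.embd) are compared against the new prompt, and llm.cursor is set to the length of the shared prefix so next() only re-evaluates the suffix. A standalone sketch of that comparison (plain ints instead of C.llama_token):

	package main

	import "fmt"

	// commonPrefix mirrors the cursor loop in marshalPrompt: count the leading
	// tokens the cached sequence and the new prompt share, then back off one
	// token if everything matched so sampling still gets fresh logits.
	func commonPrefix(cached, tokens []int) int {
		var i int
		for i = 0; i < len(cached) && i < len(tokens) && cached[i] == tokens[i]; i++ {
			// noop
		}

		if i == len(tokens) && i > 0 {
			i--
		}

		return i
	}

	func main() {
		cached := []int{1, 2, 3, 4}
		next := []int{1, 2, 3, 9, 10}
		fmt.Println(commonPrefix(cached, next)) // 3: only tokens 9 and 10 need eval
	}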
llama/utils.go

 package llama
 
-type node[T any] struct {
-	t    T
-	next *node[T]
-	prev *node[T]
-}
-
-type deque[T any] struct {
-	head     *node[T]
-	tail     *node[T]
-	size     int
-	capacity int
-}
-
-func (d *deque[T]) Empty() bool {
-	return d.size == 0
-}
-
-func (d *deque[T]) Len() int {
-	return d.size
-}
-
-func (d *deque[T]) Cap() int {
-	return d.capacity
-}
-
-func (d *deque[T]) Push(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.PopLeft()
-	}
-
-	n := node[T]{t: t}
-	if d.head != nil {
-		n.next = d.head
-		d.head.prev = &n
-		d.head = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) PushLeft(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.Pop()
-	}
-
-	n := node[T]{t: t}
-	if d.tail != nil {
-		n.prev = d.tail
-		d.tail.next = &n
-		d.tail = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) Pop() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	head := d.head
-	d.head = head.next
-	if d.head != nil {
-		d.head.prev = nil
-	} else {
-		d.tail = nil
-	}
-
-	d.size--
-	return &head.t
-}
-
-func (d *deque[T]) PopLeft() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	tail := d.tail
-	d.tail = tail.prev
-	if d.tail != nil {
-		d.tail.next = nil
-	} else {
-		d.head = nil
-	}
-
-	d.size--
-	return &tail.t
-}
-
-func (d *deque[T]) Data() (data []T) {
-	for n := d.head; n != nil; n = n.next {
-		data = append(data, n.t)
-	}
-
-	return data
-}
+import (
+	"fmt"
+	"time"
+)
+
+func parseDurationMs(ms float64) time.Duration {
+	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+	if err != nil {
+		panic(err)
+	}
+
+	return dur
+}
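Note: llama.cpp reports its timings as fractional milliseconds in C doubles; parseDurationMs is the small shim that turns those into time.Duration for the response fields. The same conversion, runnable on its own:

	package main

	import (
		"fmt"
		"time"
	)

	// Identical to parseDurationMs above: format the float as "<ms>ms" and
	// let time.ParseDuration handle the unit conversion.
	func parseDurationMs(ms float64) time.Duration {
		dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
		if err != nil {
			panic(err)
		}

		return dur
	}

	func main() {
		fmt.Println(parseDurationMs(1234.5)) // 1.2345s
		fmt.Println(parseDurationMs(0.75))   // 750µs
	}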
server/routes.go

@@ -11,6 +11,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 	"time"
 
 	"dario.cat/mergo"

@@ -21,8 +22,21 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
+var activeSession struct {
+	mu sync.Mutex
+
+	id  int64
+	llm *llama.LLM
+
+	expireAt    time.Time
+	expireTimer *time.Timer
+}
+
 func GenerateHandler(c *gin.Context) {
-	start := time.Now()
+	activeSession.mu.Lock()
+	defer activeSession.mu.Unlock()
+
+	checkpointStart := time.Now()
 
 	var req api.GenerateRequest
 	if err := c.ShouldBindJSON(&req); err != nil {

@@ -36,6 +50,12 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	if req.SessionID == 0 || req.SessionID != activeSession.id {
+		if activeSession.llm != nil {
+			activeSession.llm.Close()
+			activeSession.llm = nil
+		}
+
 	opts := api.DefaultOptions()
 	if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

@@ -47,33 +67,68 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, err := model.Prompt(req)
+	llm, err := llama.New(model.ModelPath, opts)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	llm, err := llama.New(model.ModelPath, opts)
+		activeSession.id = time.Now().UnixNano()
+		activeSession.llm = llm
+	}
+
+	sessionDuration := req.SessionDuration
+	sessionID := activeSession.id
+
+	activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+	if activeSession.expireTimer == nil {
+		activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
+			activeSession.mu.Lock()
+			defer activeSession.mu.Unlock()
+
+			if sessionID != activeSession.id {
+				return
+			}
+
+			if time.Now().Before(activeSession.expireAt) {
+				return
+			}
+
+			activeSession.llm.Close()
+			activeSession.llm = nil
+			activeSession.id = 0
+		})
+	}
+	activeSession.expireTimer.Reset(sessionDuration.Duration)
+
+	checkpointLoaded := time.Now()
+
+	prompt, err := model.Prompt(req)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	defer llm.Close()
 
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
 		fn := func(r api.GenerateResponse) {
+			activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+			activeSession.expireTimer.Reset(sessionDuration.Duration)
+
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
+			r.SessionID = activeSession.id
+			r.SessionExpiresAt = activeSession.expireAt.UTC()
+
 			if r.Done {
-				r.TotalDuration = time.Since(start)
+				r.TotalDuration = time.Since(checkpointStart)
+				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
 
 			ch <- r
 		}
 
-		if err := llm.Predict(req.Context, prompt, fn); err != nil {
+		if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()

@@ -223,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
 		return
 	}
 
-	c.JSON(http.StatusOK, api.ListResponse{models})
+	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 
 func CopyModelHandler(c *gin.Context) {
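Note: GenerateHandler now keeps one model resident: if the incoming session_id matches the active session, the already-loaded llama.LLM is reused; otherwise the old one is closed and a fresh session id (UnixNano) is issued, with a timer unloading the model after session_duration of inactivity. A client sketch of keeping a session warm by echoing session_id back; the /api/generate path, the port, the one-JSON-object-per-line streaming, and the model name are assumptions not shown in this diff:

	package main

	import (
		"bufio"
		"bytes"
		"encoding/json"
		"fmt"
		"net/http"
	)

	type request struct {
		Model           string `json:"model"`
		Prompt          string `json:"prompt"`
		SessionID       int64  `json:"session_id"`
		SessionDuration string `json:"session_duration,omitempty"`
	}

	type response struct {
		Response  string `json:"response,omitempty"`
		SessionID int64  `json:"session_id"`
		Done      bool   `json:"done"`
	}

	// generate posts one prompt and returns the session id from the final
	// (done) chunk so the next call can reuse the loaded model.
	func generate(prompt string, session int64) (int64, error) {
		body, _ := json.Marshal(request{Model: "mymodel", Prompt: prompt, SessionID: session, SessionDuration: "5m"})
		resp, err := http.Post("http://127.0.0.1:11434/api/generate", "application/json", bytes.NewReader(body))
		if err != nil {
			return 0, err
		}
		defer resp.Body.Close()

		// assume responses stream as one JSON object per line
		scanner := bufio.NewScanner(resp.Body)
		for scanner.Scan() {
			var r response
			if err := json.Unmarshal(scanner.Bytes(), &r); err != nil {
				return 0, err
			}

			fmt.Print(r.Response)
			if r.Done {
				return r.SessionID, nil
			}
		}

		return 0, scanner.Err()
	}

	func main() {
		id, err := generate("why is the sky blue?", 0) // 0: start a new session
		if err != nil {
			panic(err)
		}

		if _, err := generate("and why is it red at sunset?", id); err != nil {
			panic(err)
		}
	}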