Unverified Commit 4c4c730a authored by mraiser's avatar mraiser Committed by GitHub
Browse files

Merge branch 'ollama:main' into main

parents 6eb3cddc e02ecfb6
...@@ -61,3 +61,38 @@ PARAMETER param1 ...@@ -61,3 +61,38 @@ PARAMETER param1
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorContains(t, err, "missing value for [param1]")
} }
func Test_Parser_Messages(t *testing.T) {
input := `
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
}
assert.Equal(t, expectedCommands, commands)
}
// Test_Parser_Messages_BadRole checks that a MESSAGE line whose role is not
// one of the accepted values makes Parse return an error that lists the
// accepted roles.
func Test_Parser_Messages_BadRole(t *testing.T) {
	const modelfile = `
FROM foo
MESSAGE badguy I'm a bad guy!
`
	// Substring the returned error must contain.
	const wantErr = "role must be one of \"system\", \"user\", or \"assistant\""

	_, err := Parse(strings.NewReader(modelfile))
	assert.ErrorContains(t, err, wantErr)
}
...@@ -13,3 +13,13 @@ docker build \ ...@@ -13,3 +13,13 @@ docker build \
-f Dockerfile \ -f Dockerfile \
-t ollama/ollama:$VERSION \ -t ollama/ollama:$VERSION \
. .
# Build the linux/amd64 image from the "runtime-rocm" stage of Dockerfile and
# tag it ollama/ollama:$VERSION-rocm. VERSION and GOFLAGS are passed as build
# args without an explicit value, so they are expected to be set in the
# calling environment.
docker build \
--load \
--platform=linux/amd64 \
--build-arg=VERSION \
--build-arg=GOFLAGS \
--target runtime-rocm \
-f Dockerfile \
-t ollama/ollama:$VERSION-rocm \
.
...@@ -25,6 +25,11 @@ import ( ...@@ -25,6 +25,11 @@ import (
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
) )
// maxRetries is the retry bound for downloads; the loop that consumes it is
// not part of this declaration block.
const (
	maxRetries = 6
)

var (
	// errMaxRetriesExceeded is the sentinel reported once the retry bound has
	// been used up.
	errMaxRetriesExceeded = errors.New("max retries exceeded")

	// errPartStalled is the sentinel returned when a part has received no
	// data for too long and should be attempted again.
	errPartStalled = errors.New("part stalled")
)
var blobDownloadManager sync.Map var blobDownloadManager sync.Map
type blobDownload struct { type blobDownload struct {
...@@ -48,6 +53,7 @@ type blobDownloadPart struct { ...@@ -48,6 +53,7 @@ type blobDownloadPart struct {
Offset int64 Offset int64
Size int64 Size int64
Completed int64 Completed int64
lastUpdated time.Time
*blobDownload `json:"-"` *blobDownload `json:"-"`
} }
...@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 { ...@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size return p.Offset + p.Size
} }
// Write implements io.Writer for progress accounting only: the bytes in b are
// not stored anywhere. Each call adds len(b) to the parent download's
// Completed counter and records the current time in lastUpdated, which lets
// an observer tell when data last arrived for this part.
//
// It always reports len(b) bytes written and a nil error.
//
// NOTE(review): lastUpdated is a plain field; if it is read from another
// goroutine while a download is writing, the access is unsynchronized.
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
	received := len(b)
	p.lastUpdated = time.Now()
	p.blobDownload.Completed.Add(int64(received))
	return received, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*") partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil { if err != nil {
...@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis ...@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space // return immediately if the context is canceled or the device is out of space
return err return err
case errors.Is(err, errPartStalled):
try--
continue
case err != nil: case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)) slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
...@@ -195,6 +211,8 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis ...@@ -195,6 +211,8 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
} }
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
...@@ -203,7 +221,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w ...@@ -203,7 +221,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
} }
defer resp.Body.Close() defer resp.Body.Close()
n, err := io.Copy(w, io.TeeReader(resp.Body, b)) n, err := io.Copy(w, io.TeeReader(resp.Body, part))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress // rollback progress
b.Completed.Add(-n) b.Completed.Add(-n)
...@@ -217,6 +235,30 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w ...@@ -217,6 +235,30 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
// return nil or context.Canceled or UnexpectedEOF (resumable) // return nil or context.Canceled or UnexpectedEOF (resumable)
return err return err
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
for {
select {
case <-ticker.C:
if part.Completed >= part.Size {
return nil
}
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
// reset last updated
part.lastUpdated = time.Time{}
return errPartStalled
}
case <-ctx.Done():
return ctx.Err()
}
}
})
return g.Wait()
} }
func (b *blobDownload) newPart(offset, size int64) error { func (b *blobDownload) newPart(offset, size int64) error {
...@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error ...@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
return json.NewEncoder(partFile).Encode(part) return json.NewEncoder(partFile).Encode(part)
} }
// Write implements io.Writer as a byte counter: it adds len(p) to b.Completed
// and discards the data. It always reports len(p) bytes written and a nil
// error.
func (b *blobDownload) Write(p []byte) (n int, err error) {
	size := len(p)
	b.Completed.Add(int64(size))
	return size, nil
}
func (b *blobDownload) acquire() { func (b *blobDownload) acquire() {
b.references.Add(1) b.references.Add(1)
} }
...@@ -279,10 +315,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) ...@@ -279,10 +315,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
case <-ctx.Done():
return ctx.Err()
}
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]), Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest, Digest: b.Digest,
...@@ -293,6 +325,9 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) ...@@ -293,6 +325,9 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
if b.done || b.err != nil { if b.done || b.err != nil {
return b.err return b.err
} }
case <-ctx.Done():
return ctx.Err()
}
} }
} }
...@@ -303,10 +338,6 @@ type downloadOpts struct { ...@@ -303,10 +338,6 @@ type downloadOpts struct {
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }
// maxRetries is the retry bound for downloads; the loop that consumes it is
// not part of this declaration block.
const (
	maxRetries = 6
)

// errMaxRetriesExceeded is the sentinel reported once the retry bound has
// been used up.
var (
	errMaxRetriesExceeded = errors.New("max retries exceeded")
)
// downloadBlob downloads a blob from the registry and stores it in the blobs directory // downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error { func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest) fp, err := GetBlobsPath(opts.digest)
......
...@@ -41,7 +41,7 @@ type Model struct { ...@@ -41,7 +41,7 @@ type Model struct {
Config ConfigV2 Config ConfigV2
ShortName string ShortName string
ModelPath string ModelPath string
OriginalModel string ParentModel string
AdapterPaths []string AdapterPaths []string
ProjectorPaths []string ProjectorPaths []string
Template string Template string
...@@ -50,6 +50,12 @@ type Model struct { ...@@ -50,6 +50,12 @@ type Model struct {
Digest string Digest string
Size int64 Size int64
Options map[string]interface{} Options map[string]interface{}
Messages []Message
}
// Message is one role/content pair as stored with a model. Both fields are
// plain strings and are serialized to JSON under lower-case keys.
type Message struct {
	// Role is serialized under the JSON key "role".
	Role string `json:"role"`

	// Content is serialized under the JSON key "content".
	Content string `json:"content"`
}
type PromptVars struct { type PromptVars struct {
...@@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) { ...@@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename model.ModelPath = filename
model.OriginalModel = layer.From model.ParentModel = layer.From
case "application/vnd.ollama.image.embed": case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2 // Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version // TODO: remove this warning in a future version
...@@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) { ...@@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil { if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.messages":
msgs, err := os.Open(filename)
if err != nil {
return nil, err
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license": case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
...@@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars ...@@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
} }
var layers Layers var layers Layers
messages := []string{}
params := make(map[string][]string) params := make(map[string][]string)
fromParams := make(map[string]any) fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args))
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
...@@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars ...@@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
} }
layers.Replace(layer) layers.Replace(layer)
case "message":
messages = append(messages, c.Args)
default: default:
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
if len(messages) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
msgs := make([]api.Message, 0)
for _, m := range messages {
// todo: handle images
msg := strings.SplitN(m, ": ", 2)
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
return err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return err
}
layers.Replace(layer)
}
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
...@@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) { ...@@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) {
mt.Model = model mt.Model = model
mt.From = model.ModelPath mt.From = model.ModelPath
if model.OriginalModel != "" { if model.ParentModel != "" {
mt.From = model.OriginalModel mt.From = model.ParentModel
} }
modelFile := `# Modelfile generated by "ollama show" modelFile := `# Modelfile generated by "ollama show"
......
...@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) { ...@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
return return
} }
sessionDuration := defaultSessionDuration var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) { ...@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -659,6 +672,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ...@@ -659,6 +672,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
modelDetails := api.ModelDetails{ modelDetails := api.ModelDetails{
ParentModel: model.ParentModel,
Format: model.Config.ModelFormat, Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily, Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies, Families: model.Config.ModelFamilies,
...@@ -674,11 +688,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ...@@ -674,11 +688,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model.Template = req.Template model.Template = req.Template
} }
msgs := make([]api.Message, 0)
for _, msg := range model.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
resp := &api.ShowResponse{ resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"), License: strings.Join(model.License, "\n"),
System: model.System, System: model.System,
Template: model.Template, Template: model.Template,
Details: modelDetails, Details: modelDetails,
Messages: msgs,
} }
var params []string var params []string
...@@ -1067,7 +1087,14 @@ func ChatHandler(c *gin.Context) { ...@@ -1067,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -1075,7 +1102,13 @@ func ChatHandler(c *gin.Context) { ...@@ -1075,7 +1102,13 @@ func ChatHandler(c *gin.Context) {
// an empty request loads the model // an empty request loads the model
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}}) resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return return
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment