Unverified Commit 2fb52261 authored by Patrick Devine's avatar Patrick Devine Committed by GitHub
Browse files

basic distribution w/ push/pull (#78)



* basic distribution w/ push/pull

* add the parser

* add create, pull, and push

* changes to the parser, FROM line, and fix commands

* mkdirp new manifest directories

* make `blobs` directory if it does not exist

* fix go warnings

* add progressbar for model pulls

* move model struct

---------
Co-authored-by: default avatarJeffrey Morgan <jmorganca@gmail.com>
parent 6fdea030
...@@ -116,3 +116,29 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc ...@@ -116,3 +116,29 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
return fn(resp) return fn(resp)
}) })
} }
type PushProgressFunc func(PushProgress) error
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
var resp PushProgress
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
type CreateProgressFunc func(CreateProgress) error
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp CreateProgress
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
...@@ -7,22 +7,49 @@ import ( ...@@ -7,22 +7,49 @@ import (
"time" "time"
) )
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"`
Options `json:"options"`
}
type CreateRequest struct {
Name string `json:"name"`
Path string `json:"path"`
}
type CreateProgress struct {
Status string `json:"status"`
}
type PullRequest struct { type PullRequest struct {
Model string `json:"model"` Name string `json:"name"`
Username string `json:"username"`
Password string `json:"password"`
} }
type PullProgress struct { type PullProgress struct {
Total int64 `json:"total"` Status string `json:"status"`
Completed int64 `json:"completed"` Digest string `json:"digest,omitempty"`
Percent float64 `json:"percent"` Total int `json:"total,omitempty"`
Completed int `json:"completed,omitempty"`
Percent float64 `json:"percent,omitempty"`
} }
type GenerateRequest struct { type PushRequest struct {
Model string `json:"model"` Name string `json:"name"`
Prompt string `json:"prompt"` Username string `json:"username"`
Context []int `json:"context,omitempty"` Password string `json:"password"`
}
Options `json:"options"` type PushProgress struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
Total int `json:"total,omitempty"`
Completed int `json:"completed,omitempty"`
Percent float64 `json:"percent,omitempty"`
} }
type GenerateResponse struct { type GenerateResponse struct {
......
...@@ -30,6 +30,23 @@ func cacheDir() string { ...@@ -30,6 +30,23 @@ func cacheDir() string {
return filepath.Join(home, ".ollama") return filepath.Join(home, ".ollama")
} }
func create(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file")
client := api.NewClient()
request := api.CreateRequest{Name: args[0], Path: filename}
fn := func(resp api.CreateProgress) error {
fmt.Println(resp.Status)
return nil
}
if err := client.Create(context.Background(), &request, fn); err != nil {
return err
}
return nil
}
func RunRun(cmd *cobra.Command, args []string) error { func RunRun(cmd *cobra.Command, args []string) error {
_, err := os.Stat(args[0]) _, err := os.Stat(args[0])
switch { switch {
...@@ -51,25 +68,56 @@ func RunRun(cmd *cobra.Command, args []string) error { ...@@ -51,25 +68,56 @@ func RunRun(cmd *cobra.Command, args []string) error {
return RunGenerate(cmd, args) return RunGenerate(cmd, args)
} }
func push(cmd *cobra.Command, args []string) error {
client := api.NewClient()
request := api.PushRequest{Name: args[0]}
fn := func(resp api.PushProgress) error {
fmt.Println(resp.Status)
return nil
}
if err := client.Push(context.Background(), &request, fn); err != nil {
return err
}
return nil
}
func RunPull(cmd *cobra.Command, args []string) error {
return pull(args[0])
}
func pull(model string) error { func pull(model string) error {
client := api.NewClient() client := api.NewClient()
var bar *progressbar.ProgressBar var bar *progressbar.ProgressBar
return client.Pull(
context.Background(),
&api.PullRequest{Model: model},
func(progress api.PullProgress) error {
if bar == nil {
if progress.Percent >= 100 {
// already downloaded
return nil
}
bar = progressbar.DefaultBytes(progress.Total) currentLayer := ""
request := api.PullRequest{Name: model}
fn := func(resp api.PullProgress) error {
if resp.Digest != currentLayer && resp.Digest != "" {
if currentLayer != "" {
fmt.Println()
} }
currentLayer = resp.Digest
layerStr := resp.Digest[7:23] + "..."
bar = progressbar.DefaultBytes(
int64(resp.Total),
"pulling "+layerStr,
)
} else if resp.Digest == currentLayer && resp.Digest != "" {
bar.Set(resp.Completed)
} else {
currentLayer = ""
fmt.Println(resp.Status)
}
return nil
}
return bar.Set64(progress.Completed) if err := client.Pull(context.Background(), &request, fn); err != nil {
}, return err
) }
return nil
} }
func RunGenerate(cmd *cobra.Command, args []string) error { func RunGenerate(cmd *cobra.Command, args []string) error {
...@@ -215,6 +263,15 @@ func NewCLI() *cobra.Command { ...@@ -215,6 +263,15 @@ func NewCLI() *cobra.Command {
cobra.EnableCommandSorting = false cobra.EnableCommandSorting = false
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model from a Modelfile",
Args: cobra.MinimumNArgs(1),
RunE: create,
}
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
runCmd := &cobra.Command{ runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]", Use: "run MODEL [PROMPT]",
Short: "Run a model", Short: "Run a model",
...@@ -231,9 +288,26 @@ func NewCLI() *cobra.Command { ...@@ -231,9 +288,26 @@ func NewCLI() *cobra.Command {
RunE: RunServer, RunE: RunServer,
} }
pullCmd := &cobra.Command{
Use: "pull MODEL",
Short: "Pull a model from a registry",
Args: cobra.MinimumNArgs(1),
RunE: RunPull,
}
pushCmd := &cobra.Command{
Use: "push MODEL",
Short: "Push a model to a registry",
Args: cobra.MinimumNArgs(1),
RunE: push,
}
rootCmd.AddCommand( rootCmd.AddCommand(
serveCmd, serveCmd,
createCmd,
runCmd, runCmd,
pullCmd,
pushCmd,
) )
return rootCmd return rootCmd
......
package parser
import (
"bufio"
"fmt"
"io"
"strings"
)
type Command struct {
Name string
Arg string
}
func Parse(reader io.Reader) ([]Command, error) {
var commands []Command
var foundModel bool
scanner := bufio.NewScanner(reader)
multiline := false
var multilineCommand *Command
for scanner.Scan() {
line := scanner.Text()
if multiline {
// If we're in a multiline string and the line is """, end the multiline string.
if strings.TrimSpace(line) == `"""` {
multiline = false
commands = append(commands, *multilineCommand)
} else {
// Otherwise, append the line to the multiline string.
multilineCommand.Arg += "\n" + line
}
continue
}
fields := strings.Fields(line)
if len(fields) == 0 {
continue
}
command := Command{}
switch fields[0] {
case "FROM":
command.Name = "model"
command.Arg = fields[1]
if command.Arg == "" {
return nil, fmt.Errorf("no model specified in FROM line")
}
foundModel = true
case "PROMPT":
command.Name = "prompt"
if fields[1] == `"""` {
multiline = true
multilineCommand = &command
multilineCommand.Arg = ""
} else {
command.Arg = strings.Join(fields[1:], " ")
}
case "PARAMETER":
command.Name = fields[1]
command.Arg = strings.Join(fields[2:], " ")
default:
continue
}
if !multiline {
commands = append(commands, command)
}
}
if !foundModel {
return nil, fmt.Errorf("no FROM line for the model was specified")
}
if multiline {
return nil, fmt.Errorf("unclosed multiline string")
}
return commands, scanner.Err()
}
This diff is collapsed.
package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
)
const directoryURL = "https://ollama.ai/api/models"
type Model struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
Parameters string `json:"parameters"`
URL string `json:"url"`
ShortDescription string `json:"short_description"`
Description string `json:"description"`
PublishedBy string `json:"published_by"`
OriginalAuthor string `json:"original_author"`
OriginalURL string `json:"original_url"`
License string `json:"license"`
}
func (m *Model) FullName() string {
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return filepath.Join(home, ".ollama", "models", m.Name+".bin")
}
func (m *Model) TempFile() string {
fullName := m.FullName()
return filepath.Join(
filepath.Dir(fullName),
fmt.Sprintf(".%s.part", filepath.Base(fullName)),
)
}
func getRemote(model string) (*Model, error) {
// resolve the model download from our directory
resp, err := http.Get(directoryURL)
if err != nil {
return nil, fmt.Errorf("failed to get directory: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read directory: %w", err)
}
var models []Model
err = json.Unmarshal(body, &models)
if err != nil {
return nil, fmt.Errorf("failed to parse directory: %w", err)
}
for _, m := range models {
if m.Name == model {
return &m, nil
}
}
return nil, fmt.Errorf("model not found in directory: %s", model)
}
func saveModel(model *Model, fn func(total, completed int64)) error {
// this models cache directory is created by the server on startup
client := &http.Client{}
req, err := http.NewRequest("GET", model.URL, nil)
if err != nil {
return fmt.Errorf("failed to download model: %w", err)
}
var size int64
// completed file doesn't exist, check partial file
fi, err := os.Stat(model.TempFile())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
}
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to download model: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return fmt.Errorf("failed to download model: %s", resp.Status)
}
out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
panic(err)
}
defer out.Close()
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
completed := size
total := remaining + completed
for {
fn(total, completed)
if completed >= total {
return os.Rename(model.TempFile(), model.FullName())
}
n, err := io.CopyN(out, resp.Body, 8192)
if err != nil && !errors.Is(err, io.EOF) {
return err
}
completed += n
}
}
package server package server
import ( import (
"embed"
"encoding/json" "encoding/json"
"errors" "fmt"
"io" "io"
"log" "log"
"math"
"net" "net"
"net/http" "net/http"
"os" "os"
...@@ -16,16 +14,11 @@ import ( ...@@ -16,16 +14,11 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/lithammer/fuzzysearch/fuzzy"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
) )
//go:embed templates/*
var templatesFS embed.FS
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
func cacheDir() string { func cacheDir() string {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
...@@ -40,6 +33,7 @@ func generate(c *gin.Context) { ...@@ -40,6 +33,7 @@ func generate(c *gin.Context) {
req := api.GenerateRequest{ req := api.GenerateRequest{
Options: api.DefaultOptions(), Options: api.DefaultOptions(),
Prompt: "",
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
...@@ -47,34 +41,28 @@ func generate(c *gin.Context) { ...@@ -47,34 +41,28 @@ func generate(c *gin.Context) {
return return
} }
if remoteModel, _ := getRemote(req.Model); remoteModel != nil { model, err := GetModel(req.Model)
req.Model = remoteModel.FullName() if err != nil {
} c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
if _, err := os.Stat(req.Model); err != nil { return
if !errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Model = filepath.Join(cacheDir(), "models", req.Model+".bin")
} }
templateNames := make([]string, 0, len(templates.Templates())) templ, err := template.New("").Parse(model.Prompt)
for _, template := range templates.Templates() { if err != nil {
templateNames = append(templateNames, template.Name()) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
} }
match, _ := matchRankOne(filepath.Base(req.Model), templateNames) var sb strings.Builder
if template := templates.Lookup(match); template != nil { if err = templ.Execute(&sb, req); err != nil {
var sb strings.Builder c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if err := template.Execute(&sb, req); err != nil { return
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
req.Prompt = sb.String()
} }
req.Prompt = sb.String()
llm, err := llama.New(req.Model, req.Options) fmt.Printf("prompt = >>>%s<<<\n", req.Prompt)
llm, err := llama.New(model.ModelPath, req.Options)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -105,40 +93,84 @@ func pull(c *gin.Context) { ...@@ -105,40 +93,84 @@ func pull(c *gin.Context) {
return return
} }
remote, err := getRemote(req.Model) ch := make(chan any)
if err != nil { go func() {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) defer close(ch)
fn := func(status, digest string, total, completed int, percent float64) {
ch <- api.PullProgress{
Status: status,
Digest: digest,
Total: total,
Completed: completed,
Percent: percent,
}
}
if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}()
streamResponse(c, ch)
}
func push(c *gin.Context) {
var req api.PushRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
// check if completed file exists ch := make(chan any)
fi, err := os.Stat(remote.FullName()) go func() {
switch { defer close(ch)
case errors.Is(err, os.ErrNotExist): fn := func(status, digest string, total, completed int, percent float64) {
// noop, file doesn't exist so create it ch <- api.PushProgress{
case err != nil: Status: status,
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) Digest: digest,
Total: total,
Completed: completed,
Percent: percent,
}
}
if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}()
streamResponse(c, ch)
}
func create(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return return
default: }
c.JSON(http.StatusOK, api.PullProgress{
Total: fi.Size(), // NOTE consider passing the entire Modelfile in the json instead of the path to it
Completed: fi.Size(),
Percent: 100,
})
file, err := os.Open(req.Path)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return return
} }
defer file.Close()
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
saveModel(remote, func(total, completed int64) { fn := func(status string) {
ch <- api.PullProgress{ ch <- api.CreateProgress{
Total: total, Status: status,
Completed: completed,
Percent: float64(completed) / float64(total) * 100,
} }
}) }
if err := CreateModel(req.Name, file, fn); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
}
}() }()
streamResponse(c, ch) streamResponse(c, ch)
...@@ -153,6 +185,8 @@ func Serve(ln net.Listener) error { ...@@ -153,6 +185,8 @@ func Serve(ln net.Listener) error {
r.POST("/api/pull", pull) r.POST("/api/pull", pull)
r.POST("/api/generate", generate) r.POST("/api/generate", generate)
r.POST("/api/create", create)
r.POST("/api/push", push)
log.Printf("Listening on %s", ln.Addr()) log.Printf("Listening on %s", ln.Addr())
s := &http.Server{ s := &http.Server{
...@@ -162,18 +196,6 @@ func Serve(ln net.Listener) error { ...@@ -162,18 +196,6 @@ func Serve(ln net.Listener) error {
return s.Serve(ln) return s.Serve(ln)
} }
func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
bestRank = math.MaxInt
for _, target := range targets {
if rank := fuzzy.LevenshteinDistance(source, target); bestRank > rank {
bestRank = rank
bestMatch = target
}
}
return
}
func streamResponse(c *gin.Context, ch chan any) { func streamResponse(c *gin.Context, ch chan any) {
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
val, ok := <-ch val, ok := <-ch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment