Unverified commit 348b3e09 authored by Blake Mizerany, committed by GitHub

server/internal: copy bmizerany/ollama-go to internal package (#9294)

This commit copies (without history) the bmizerany/ollama-go repository
with the intention of integrating it into ollama as a replacement for
pushing and pulling models, and for managing the cache they are pushed
to and pulled from.

New homes for these packages will be determined as they are integrated
and we have a better understanding of proper package boundaries.
parent 0b7e1676
@@ -147,6 +147,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     env:
       CGO_ENABLED: '1'
+      GOEXPERIMENT: 'synctest'
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
...
@@ -5,7 +5,6 @@
 .swp
 dist
 build
-ollama
 .cache
 *.exe
 .idea
@@ -14,3 +13,4 @@ test_data
 __debug_bin*
 llama/build
 llama/vendor
+/ollama
@@ -6,8 +6,6 @@ linters:
     - bidichk
     - bodyclose
     - containedctx
-    - contextcheck
-    - errcheck
     - gocheckcompilerdirectives
     - gofmt
     - gofumpt
@@ -23,10 +21,11 @@ linters:
     - staticcheck
     - tenv
     - unconvert
-    - unused
-    - usestdlibvars
     - wastedassign
     - whitespace
+  disable:
+    - usestdlibvars
+    - errcheck
 linters-settings:
   staticcheck:
     checks:
@@ -39,5 +38,4 @@ severity:
     - gofmt
     - goimports
     - intrange
-    - usestdlibvars
 severity: info
// Package blob implements a content-addressable disk cache for blobs and
// manifests.
package blob
import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"os"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/server/internal/internal/names"
)
// Entry contains metadata about a blob in the cache.
type Entry struct {
Digest Digest
Size int64
Time time.Time // when added to the cache
}
// DiskCache caches blobs and manifests on disk.
//
// The cache is rooted at a directory, which is created if it does not exist.
//
// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the
// "manifests" subdirectory. A example directory structure might look like:
//
// <dir>/
// blobs/
// sha256-<digest> - <blob data>
// manifests/
// <host>/
// <namespace>/
// <name>/
// <tag> - <manifest data>
//
// Name casing is preserved in the cache, but is not significant when resolving
// names. For example, "Foo" and "foo" are considered the same name.
//
// The cache is safe for concurrent use. It guards concurrent writes, but does
// not prevent duplicated effort. Because blobs are immutable, duplicate writes
// should result in the same file being written to disk.
type DiskCache struct {
// Dir specifies the top-level directory where blobs and manifest
// pointers are stored.
dir string
now func() time.Time
testHookBeforeFinalWrite func(f *os.File)
}
// PutBytes is a convenience function for c.Put(d, bytes.NewReader([]byte(data)), int64(len(data))).
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
}
// Open opens a cache rooted at the given directory. If the directory does not
// exist, it is created. If the directory is not a directory, an error is
// returned.
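//
// A minimal usage sketch (hedged; the directory and content here are
// hypothetical):
//
//	c, err := Open(filepath.Join(os.TempDir(), "blob-cache"))
//	if err != nil { /* handle */ }
//	d := DigestFromBytes("hello")
//	if err := PutBytes(c, d, "hello"); err != nil { /* handle */ }
//	e, _ := c.Get(d) // e.Size == 5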
func Open(dir string) (*DiskCache, error) {
if dir == "" {
return nil, errors.New("blob: empty directory name")
}
info, err := os.Stat(dir)
if err == nil && !info.IsDir() {
return nil, fmt.Errorf("%q is not a directory", dir)
}
if err := os.MkdirAll(dir, 0o777); err != nil {
return nil, err
}
subdirs := []string{"blobs", "manifests"}
for _, subdir := range subdirs {
if err := os.MkdirAll(filepath.Join(dir, subdir), 0o777); err != nil {
return nil, err
}
}
// TODO(bmizerany): support shards
c := &DiskCache{
dir: dir,
now: time.Now,
}
return c, nil
}
func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) {
f, err := os.Open(filename)
if err != nil {
return nil, Digest{}, err
}
defer f.Close()
h := sha256.New()
r := io.TeeReader(f, h)
data, err = io.ReadAll(io.LimitReader(r, limit))
if err != nil {
return nil, Digest{}, err
}
var d Digest
h.Sum(d.sum[:0])
return data, d, nil
}
//lint:ignore U1000 used for debugging purposes as needed in tests
var debug = false
// debugger returns a function that can be used to add a step to the error message.
// The error message will be a list of steps that were taken before the error occurred.
// The steps are added in the order they are called.
//
// To set the error message, call the returned function with an empty string.
//
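// A sketch of the intended pattern (the function below is hypothetical):
//
//	func fetch() (err error) {
//		step := debugger(&err)
//		defer step("") // sets the error message if err != nil
//		step("resolve")
//		// ... work that may set err ...
//		return err
//	}
//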
//lint:ignore U1000 used for debugging purposes as needed in tests
func debugger(err *error) func(step string) {
if !debug {
return func(string) {}
}
var steps []string
return func(step string) {
if step == "" && *err != nil {
*err = fmt.Errorf("%q: %w", steps, *err)
return
}
steps = append(steps, step)
if len(steps) > 100 {
// drop the oldest steps in case of a bug that causes a flood of steps
copy(steps, steps[1:])
steps = steps[:100]
}
}
}
// Resolve resolves a name to a digest. The name is expected to
// be in either of the following forms:
//
//	@<digest>
//	<name>@<digest>
//	<name>
//
// If a digest is provided, it is returned as is and nothing else happens.
//
// If a name is provided for a manifest that exists in the cache, the digest
// of the manifest is returned. If there is no manifest in the cache, it
// returns [fs.ErrNotExist].
//
// To cover the case where a manifest may change without the cache knowing
// (e.g. it was reformatted or modified by hand), the manifest data read and
// hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in
// these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name)
if digest != "" {
return ParseDigest(digest)
}
// We want to address manifests files by digest using Get. That requires
// them to be blobs. This cannot be directly accomplished by looking in
// the blob store because manifests can change without Ollama knowing
// (e.g. a user modifies a manifests by hand then pushes it to update
// their model). We also need to support the blob caches inherited from
// older versions of Ollama, which do not store manifests in the blob
// store, so for these cases, we need to handle adding the manifests to
// the blob store, just in time.
//
// So now we read the manifests file, hash it, and copy it to the blob
// store if it's not already there.
//
// This should be cheap because manifests are small, and accessed
// infrequently.
file, err := c.manifestPath(name)
if err != nil {
return Digest{}, err
}
data, d, err := readAndSum(file, 1<<20)
if err != nil {
return Digest{}, err
}
// Ideally we'd read the "manifest" file as a manifest to the blob file,
// but we are not changing this yet, so copy the manifest to the blob
// store so it can be addressed by digest in subsequent calls to Get.
if err := PutBytes(c, d, data); err != nil {
return Digest{}, err
}
return d, nil
}
// Put writes a new blob to the cache, identified by its digest. The operation
// reads content from r, which must precisely match both the specified size and
// digest.
//
// Concurrent write safety is achieved through file locking. The implementation
// guarantees write integrity by enforcing size limits and content validation
// before allowing the file to reach its final state.
func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error {
return c.copyNamedFile(c.GetFile(d), r, d, size)
}
// Import imports a blob from the provided reader into the cache. It reads the
// entire content of the reader, calculates its digest, and stores it in the
// cache.
//
// Import should be considered unsafe for use with untrusted data, such as data
// read from a network. The caller is responsible for ensuring the integrity of
// the data being imported.
func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) {
// Users that want to change the temp dir can set TMPDIR.
f, err := os.CreateTemp("", "blob-")
if err != nil {
return Digest{}, err
}
defer os.Remove(f.Name())
// Copy the blob to a temporary file.
h := sha256.New()
r = io.TeeReader(r, h)
n, err := io.Copy(f, r)
if err != nil {
return Digest{}, err
}
if n != size {
return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n)
}
// Check the digest.
var d Digest
h.Sum(d.sum[:0])
if err := f.Close(); err != nil {
return Digest{}, err
}
name := c.GetFile(d)
// Rename the temporary file to the final file.
if err := os.Rename(f.Name(), name); err != nil {
return Digest{}, err
}
os.Chtimes(name, c.now(), c.now()) // mainly for tests
return d, nil
}
// Get retrieves a blob from the cache using the provided digest. The operation
// fails if the digest is malformed or if any errors occur during blob
// retrieval.
func (c *DiskCache) Get(d Digest) (Entry, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err != nil {
return Entry{}, err
}
if info.Size() == 0 {
return Entry{}, fs.ErrNotExist
}
return Entry{
Digest: d,
Size: info.Size(),
Time: info.ModTime(),
}, nil
}
// Link creates a symbolic reference in the cache that maps the provided name
// to a blob identified by its digest, making it retrievable by name using
// [Resolve].
//
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
manifest, err := c.manifestPath(name)
if err != nil {
return err
}
f, err := os.OpenFile(c.GetFile(d), os.O_RDONLY, 0)
if err != nil {
return err
}
defer f.Close()
// TODO(bmizerany): test this happens only if the blob was found to
// avoid leaving debris
if err := os.MkdirAll(filepath.Dir(manifest), 0o777); err != nil {
return err
}
info, err := f.Stat()
if err != nil {
return err
}
// Copy manifest to cache directory.
return c.copyNamedFile(manifest, f, d, info.Size())
}
// Unlink removes any link for name. If the link does not exist, nothing
// happens, and no error is returned.
//
// It returns an error if the name is invalid or if the link removal encounters
// any issues.
func (c *DiskCache) Unlink(name string) error {
manifest, err := c.manifestPath(name)
if err != nil {
return err
}
err = os.Remove(manifest)
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
// GetFile returns the absolute path to the file, in the cache, for the given
// digest. It does not check if the file exists.
//
// The returned path should not be stored, used outside the lifetime of the
// cache, or interpreted in any way.
func (c *DiskCache) GetFile(d Digest) string {
filename := fmt.Sprintf("sha256-%x", d.sum)
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
if err != nil {
yield("", err)
return
}
if !yield(pathToName(path), nil) {
return
}
}
}
}
// pathToName converts a path to a name. It is the inverse of nameToPath. The
// path is assumed to be in filepath.ToSlash format.
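//
// For example, "manifests/h/n/m/t" becomes "h/n/m:t".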
func pathToName(s string) string {
s = strings.TrimPrefix(s, "manifests/")
rr := []rune(s)
for i := len(rr) - 1; i > 0; i-- {
if rr[i] == '/' {
rr[i] = ':'
return string(rr)
}
}
return s
}
// manifestPath finds the first manifest file on disk that matches the given
// name using a case-insensitive comparison. If no manifest file is found, it
// returns the path where the manifest file would be if it existed.
//
// If two manifest files exist on disk that match the given name using a
// case-insensitive comparison, the one that sorts first, lexically, is
// returned.
func (c *DiskCache) manifestPath(name string) (string, error) {
np, err := nameToPath(name)
if err != nil {
return "", err
}
maybe := filepath.Join("manifests", np)
for l, err := range c.links() {
if err != nil {
return "", err
}
if strings.EqualFold(maybe, l) {
return filepath.Join(c.dir, l), nil
}
}
return filepath.Join(c.dir, maybe), nil
}
// links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) links() iter.Seq2[string, error] {
// TODO(bmizerany): reuse empty dirnames if exist
return func(yield func(string, error) bool) {
fsys := os.DirFS(c.dir)
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
if err != nil {
yield("", err)
return
}
for _, manifest := range manifests {
if !yield(manifest, nil) {
return
}
}
}
}
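// checkWriter wraps a destination file and verifies, while bytes are
// written, that the total size does not exceed the expected size and that
// the content hashes to the expected digest before the final byte reaches
// the file.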
type checkWriter struct {
d Digest
size int64
n int64
h hash.Hash
f *os.File
err error
testHookBeforeFinalWrite func(*os.File)
}
func (w *checkWriter) seterr(err error) error {
if w.err == nil {
w.err = err
}
return err
}
// Write writes p to the underlying hash and writer. The last write to the
// underlying writer is guaranteed to be the last byte of p as verified by the
// hash.
func (w *checkWriter) Write(p []byte) (int, error) {
_, err := w.h.Write(p)
if err != nil {
return 0, w.seterr(err)
}
nextSize := w.n + int64(len(p))
if nextSize == w.size {
// last write. check hash.
sum := w.h.Sum(nil)
if !bytes.Equal(sum, w.d.sum[:]) {
return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
}
if w.testHookBeforeFinalWrite != nil {
w.testHookBeforeFinalWrite(w.f)
}
}
if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
}
n, err := w.f.Write(p)
w.n += int64(n)
return n, w.seterr(err)
}
// copyNamedFile copies file into name, expecting it to have the given Digest
// and size, if that file is not present already.
func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error {
info, err := os.Stat(name)
if err == nil && info.Size() == size {
// File already exists with correct size. This is good enough.
// We can skip expensive hash checks.
//
// TODO: Do the hash check, but give caller a way to skip it.
return nil
}
// Copy file to cache directory.
mode := os.O_RDWR | os.O_CREATE
if err == nil && info.Size() > size { // shouldn't happen, but truncate just in case
mode |= os.O_TRUNC
}
f, err := os.OpenFile(name, mode, 0o666)
if err != nil {
return err
}
defer f.Close()
if size == 0 {
// File now exists with correct size.
// Only one possible zero-length file, so contents are OK too.
// Early return here makes sure there's a "last byte" for code below.
return nil
}
// From here on, if any of the I/O writing the file fails,
// we make a best-effort attempt to truncate the file f
// before returning, to avoid leaving bad bytes in the file.
// Copy file to f, but also into h to double-check hash.
cw := &checkWriter{
d: out,
size: size,
h: sha256.New(),
f: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
}
n, err := io.Copy(cw, file)
if err != nil {
f.Truncate(0)
return err
}
if n < size {
f.Truncate(0)
return io.ErrUnexpectedEOF
}
if err := f.Close(); err != nil {
// Data might not have been written,
// but file may look like it is the right size.
// To be extra careful, remove cached file.
os.Remove(name)
return err
}
os.Chtimes(name, c.now(), c.now()) // mainly for tests
return nil
}
func splitNameDigest(s string) (name, digest string) {
i := strings.LastIndexByte(s, '@')
if i < 0 {
return s, ""
}
return s[:i], s[i+1:]
}
var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name)
if !n.IsFullyQualified() {
return "", errInvalidName
}
return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil
}
func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil {
// Likely a bug or a bad OS problem. Just panic.
panic(err)
}
return abs
}
package blob
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/ollama/ollama/server/internal/internal/testutil"
)
func init() {
debug = true
}
var epoch = func() time.Time {
d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
if d.IsZero() {
panic("time zero")
}
return d
}()
func TestOpenErrors(t *testing.T) {
exe, err := os.Executable()
if err != nil {
panic(err)
}
cases := []struct {
dir string
err string
}{
{t.TempDir(), ""},
{"", "empty directory name"},
{exe, "not a directory"},
}
for _, tt := range cases {
t.Run(tt.dir, func(t *testing.T) {
_, err := Open(tt.dir)
if tt.err == "" {
if err != nil {
t.Fatal(err)
}
return
}
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), tt.err) {
t.Fatalf("err = %v, want %q", err, tt.err)
}
})
}
}
func TestGetFile(t *testing.T) {
t.Chdir(t.TempDir())
c, err := Open(".")
if err != nil {
t.Fatal(err)
}
d := mkdigest("1")
got := c.GetFile(d)
cleaned := filepath.Clean(got)
if cleaned != got {
t.Fatalf("got is unclean: %q", got)
}
if !filepath.IsAbs(got) {
t.Fatal("got is not absolute")
}
abs, _ := filepath.Abs(c.dir)
if !strings.HasPrefix(got, abs) {
t.Fatalf("got is not local to %q", c.dir)
}
}
func TestBasic(t *testing.T) {
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
now := epoch
c.now = func() time.Time { return now }
checkEntry := entryChecker(t, c)
checkFailed := func(err error) {
t.Helper()
if err == nil {
t.Fatal("expected error")
}
}
_, err = c.Resolve("invalid")
checkFailed(err)
_, err = c.Resolve("h/n/m:t")
checkFailed(err)
dx := mkdigest("x")
d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx))
if err != nil {
t.Fatal(err)
}
if d != dx {
t.Fatalf("d = %v, want %v", d, dx)
}
_, err = c.Get(Digest{})
checkFailed(err)
// not committed yet
_, err = c.Get(dx)
checkFailed(err)
err = PutBytes(c, dx, "!")
checkFailed(err)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
checkEntry(dx, 1, now)
t0 := now
now = now.Add(1*time.Hour + 1*time.Minute)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
// check not updated
checkEntry(dx, 1, t0)
}
type sleepFunc func(d time.Duration) time.Time
func openTester(t *testing.T) (*DiskCache, sleepFunc) {
t.Helper()
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
now := epoch
c.now = func() time.Time { return now }
return c, func(d time.Duration) time.Time {
now = now.Add(d)
return now
}
}
func TestManifestPath(t *testing.T) {
check := testutil.Checker(t)
c, sleep := openTester(t)
d1 := mkdigest("1")
err := PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
t0 := sleep(0)
sleep(1 * time.Hour)
err = c.Link("h/n/m:t", d1) // nop expected
check(err)
file := must(c.manifestPath("h/n/m:t"))
info, err := os.Stat(file)
check(err)
testutil.CheckTime(t, info.ModTime(), t0)
}
func TestManifestExistsWithoutBlob(t *testing.T) {
t.Chdir(t.TempDir())
check := testutil.Checker(t)
c, err := Open(".")
check(err)
checkEntry := entryChecker(t, c)
man := must(c.manifestPath("h/n/m:t"))
os.MkdirAll(filepath.Dir(man), 0o777)
testutil.WriteFile(t, man, "1")
got, err := c.Resolve("h/n/m:t")
check(err)
want := mkdigest("1")
if got != want {
t.Fatalf("got = %v, want %v", got, want)
}
e, err := c.Get(got)
check(err)
checkEntry(got, 1, e.Time)
}
func TestPut(t *testing.T) {
c, sleep := openTester(t)
check := testutil.Checker(t)
checkEntry := entryChecker(t, c)
d := mkdigest("hello, world")
err := PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
got, err := c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("expected error, got %v", got)
}
// Put a valid blob
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with content that does not hash to the digest
err = PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put the valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob that errors during Read
err = c.Put(d, &errOnBangReader{s: "!"}, 1)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with mismatched size
err = c.Put(d, strings.NewReader("hello, world"), 11)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Final byte does not match the digest (testing commit phase)
err = PutBytes(c, d, "hello, world$")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
reset := c.setTestHookBeforeFinalWrite(func(f *os.File) {
// change mode to read-only
f.Truncate(0)
f.Chmod(0o400)
f.Close()
f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { f1.Close() })
*f = *f1
})
defer reset()
err = PutBytes(c, d, "hello, world")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
reset()
}
func TestImport(t *testing.T) {
c, _ := openTester(t)
checkEntry := entryChecker(t, c)
want := mkdigest("x")
got, err := c.Import(strings.NewReader("x"), 1)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Fatalf("digest = %v, want %v", got, want)
}
checkEntry(want, 1, epoch)
got, err = c.Import(strings.NewReader("x"), 1)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Fatalf("digest = %v, want %v", got, want)
}
checkEntry(want, 1, epoch)
}
func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) {
old := c.testHookBeforeFinalWrite
c.testHookBeforeFinalWrite = h
return func() { c.testHookBeforeFinalWrite = old }
}
func TestPutGetZero(t *testing.T) {
c, sleep := openTester(t)
check := testutil.Checker(t)
checkEntry := entryChecker(t, c)
d := mkdigest("x")
err := PutBytes(c, d, "x")
check(err)
checkEntry(d, 1, sleep(0))
err = os.Truncate(c.GetFile(d), 0)
check(err)
_, err = c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
}
func TestPutZero(t *testing.T) {
c, _ := openTester(t)
d := mkdigest("x")
err := c.Put(d, strings.NewReader("x"), 0) // size == 0 (not size of content)
testutil.Check(t, err)
checkNotExists(t, c, d)
}
func TestCommit(t *testing.T) {
check := testutil.Checker(t)
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
checkEntry := entryChecker(t, c)
now := epoch
c.now = func() time.Time { return now }
d1 := mkdigest("1")
err = c.Link("h/n/m:t", d1)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
err = PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
got, err := c.Resolve("h/n/m:t")
check(err)
if got != d1 {
t.Fatalf("d = %v, want %v", got, d1)
}
// commit again, more than 1 byte
d2 := mkdigest("22")
err = PutBytes(c, d2, "22")
check(err)
err = c.Link("h/n/m:t", d2)
check(err)
checkEntry(d2, 2, now)
filename := must(c.manifestPath("h/n/m:t"))
data, err := os.ReadFile(filename)
check(err)
if string(data) != "22" {
t.Fatalf("data = %q, want %q", data, "22")
}
t0 := now
now = now.Add(1 * time.Hour)
err = c.Link("h/n/m:t", d2) // same contents; nop
check(err)
info, err := os.Stat(filename)
check(err)
testutil.CheckTime(t, info.ModTime(), t0)
}
func TestManifestInvalidBlob(t *testing.T) {
c, _ := openTester(t)
d := mkdigest("1")
err := c.Link("h/n/m:t", d)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
err = PutBytes(c, d, "1")
testutil.Check(t, err)
err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666)
if err != nil {
t.Fatal(err)
}
err = c.Link("h/n/m:t", d)
if !strings.Contains(err.Error(), "underfoot") {
t.Fatalf("err = %v, want error to contain %q", err, "underfoot")
}
}
func TestManifestNameReuse(t *testing.T) {
t.Run("case-insensitive", func(t *testing.T) {
// This should run on all file system types.
testManifestNameReuse(t)
})
t.Run("case-sensitive", func(t *testing.T) {
useCaseInsensitiveTempDir(t)
testManifestNameReuse(t)
})
}
func testManifestNameReuse(t *testing.T) {
check := testutil.Checker(t)
c, _ := openTester(t)
d1 := mkdigest("1")
err := PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
d2 := mkdigest("22")
err = PutBytes(c, d2, "22")
check(err)
err = c.Link("H/N/M:T", d2)
check(err)
var g [2]Digest
g[0], err = c.Resolve("h/n/m:t")
check(err)
g[1], err = c.Resolve("H/N/M:T")
check(err)
w := [2]Digest{d2, d2}
if g != w {
t.Fatalf("g = %v, want %v", g, w)
}
var got []string
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
want := []string{"manifests/h/n/m/t"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
// relink with different case
err = c.Unlink("h/n/m:t")
check(err)
err = c.Link("h/n/m:T", d1)
check(err)
got = got[:0]
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
// we should have only one link that is same case as the last link
want = []string{"manifests/h/n/m/T"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
}
func TestManifestFile(t *testing.T) {
cases := []struct {
in string
want string
}{
{"", ""},
// valid names
{"h/n/m:t", "/manifests/h/n/m/t"},
{"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"},
{"%/%/%/%", ""},
// already a path
{"h/n/m/t", ""},
// refs are not names
{"h/n/m:t@sha256-1", ""},
{"m@sha256-1", ""},
{"n/m:t@sha256-1", ""},
}
c, _ := openTester(t)
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
got, err := c.manifestPath(tt.in)
if err != nil && tt.want != "" {
t.Fatalf("unexpected error: %v", err)
}
if err == nil && tt.want == "" {
t.Fatalf("expected error")
}
dir := filepath.ToSlash(c.dir)
got = filepath.ToSlash(got)
got = strings.TrimPrefix(got, dir)
if got != tt.want {
t.Fatalf("got = %q, want %q", got, tt.want)
}
})
}
}
func TestNames(t *testing.T) {
c, _ := openTester(t)
check := testutil.Checker(t)
check(PutBytes(c, mkdigest("1"), "1"))
check(PutBytes(c, mkdigest("2"), "2"))
check(c.Link("h/n/m:t", mkdigest("1")))
check(c.Link("h/n/m:u", mkdigest("2")))
var got []string
for l, err := range c.Links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
want := []string{"h/n/m:t", "h/n/m:u"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
}
func mkdigest(s string) Digest {
return Digest{sha256.Sum256([]byte(s))}
}
func checkNotExists(t *testing.T, c *DiskCache, d Digest) {
t.Helper()
_, err := c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
}
func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) {
t.Helper()
return func(d Digest, size int64, mod time.Time) {
t.Helper()
t.Run("checkEntry:"+d.String(), func(t *testing.T) {
t.Helper()
defer func() {
if t.Failed() {
dumpCacheContents(t, c)
}
}()
e, err := c.Get(d)
if size == 0 && errors.Is(err, fs.ErrNotExist) {
err = nil
}
if err != nil {
t.Fatal(err)
}
if e.Digest != d {
t.Errorf("e.Digest = %v, want %v", e.Digest, d)
}
if e.Size != size {
t.Fatalf("e.Size = %v, want %v", e.Size, size)
}
testutil.CheckTime(t, e.Time, mod)
info, err := os.Stat(c.GetFile(d))
if err != nil {
t.Fatal(err)
}
if info.Size() != size {
t.Fatalf("info.Size = %v, want %v", info.Size(), size)
}
testutil.CheckTime(t, info.ModTime(), mod)
})
}
}
func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}
func TestNameToPath(t *testing.T) {
_, err := nameToPath("h/n/m:t")
if err != nil {
t.Fatal(err)
}
}
type errOnBangReader struct {
s string
n int
}
func (e *errOnBangReader) Read(p []byte) (int, error) {
if len(p) < 1 {
return 0, io.ErrShortBuffer
}
if e.n >= len(e.s) {
return 0, io.EOF
}
if e.s[e.n] == '!' {
return 0, errors.New("bang")
}
p[0] = e.s[e.n]
e.n++
return 1, nil
}
func dumpCacheContents(t *testing.T, c *DiskCache) {
t.Helper()
var b strings.Builder
fsys := os.DirFS(c.dir)
fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
t.Helper()
if err != nil {
return err
}
info, err := d.Info()
if err != nil {
return err
}
// Format like ls:
//
// ; ls -la
// drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123
// drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m
fmt.Fprintf(&b, " %s % 4d %s %s\n",
info.Mode(),
info.Size(),
info.ModTime().Format("Jan 2 15:04"),
path,
)
return nil
})
t.Log()
t.Logf("cache contents:\n%s", b.String())
}
package blob
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
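// isCaseSensitive reports whether dir resides on a case-sensitive file
// system. It probes by writing a lowercase marker file and checking whether
// the uppercase spelling of the same name resolves to an existing file.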
func isCaseSensitive(dir string) bool {
defer func() {
os.Remove(filepath.Join(dir, "_casecheck"))
}()
exists := func(file string) bool {
_, err := os.Stat(file)
return err == nil
}
file := filepath.Join(dir, "_casecheck")
FILE := filepath.Join(dir, "_CASECHECK")
if exists(file) || exists(FILE) {
panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir))
}
err := os.WriteFile(file, nil, 0o666)
if err != nil {
panic(err)
}
return !exists(FILE)
}
func isCI() bool {
return os.Getenv("CI") != ""
}
const volumeHint = `
Unable to locate case-insensitive TMPDIR on darwin.
To run tests, create the case-insensitive volume /Volumes/data:
$ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data
or run with:
CI=1 go test ./...
`
// useCaseInsensitiveTempDir sets TMPDIR to a case-insensitive directory if it
// can find one; otherwise it skips the test if the CI environment variable is
// set, or GOOS is not darwin.
func useCaseInsensitiveTempDir(t *testing.T) bool {
if isCaseSensitive(os.TempDir()) {
// Use the default temp dir if it is already case-sensitive.
return true
}
if runtime.GOOS == "darwin" {
// If darwin, check for the special case-sensitive volume and
// use it if available.
const volume = "/Volumes/data"
_, err := os.Stat(volume)
if err == nil {
tmpdir := filepath.Join(volume, "tmp")
os.MkdirAll(tmpdir, 0o700)
t.Setenv("TMPDIR", tmpdir)
return true
}
if isCI() {
// Special case darwin in CI; it is not case-sensitive
// by default, and we will be testing other platforms
// that are case-sensitive, so we'll have the test
// being skipped covered there.
t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.")
}
}
if !isCI() {
// Require devs to always test with a case-insensitive TMPDIR.
// TODO(bmizerany): Print platform-specific instructions or
// link to docs on that topic.
lines := strings.Split(volumeHint, "\n")
for _, line := range lines {
t.Log(line)
}
}
return false
}
package blob
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"slices"
"strings"
)
var ErrInvalidDigest = errors.New("invalid digest")
// Digest is a blob identifier that is the SHA-256 hash of a blob's content.
//
// It is comparable and can be used as a map key.
type Digest struct {
sum [32]byte
}
// ParseDigest parses a digest from a string. If the string is not a valid
// digest, a call to the returned digest's IsValid method will return false.
//
// The input string may be in one of two forms:
//
// - ("sha256-<hex>"), where <hex> is a 64-character hexadecimal string.
// - ("sha256:<hex>"), where <hex> is a 64-character hexadecimal string.
//
// The [Digest.String] method will return the canonical form of the
// digest, "sha256:<hex>".
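//
// For example, both forms parse to the same digest (the hex value here is
// illustrative only):
//
//	d1, _ := ParseDigest("sha256-" + strings.Repeat("ab", 32))
//	d2, _ := ParseDigest("sha256:" + strings.Repeat("ab", 32))
//	// d1 == d2, and d1.String() == "sha256:abab...ab"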
func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) {
s := string(v)
i := strings.IndexAny(s, ":-")
var zero Digest
if i < 0 {
return zero, ErrInvalidDigest
}
prefix, sum := s[:i], s[i+1:]
if prefix != "sha256" || len(sum) != 64 {
return zero, ErrInvalidDigest
}
var d Digest
_, err := hex.Decode(d.sum[:], []byte(sum))
if err != nil {
return zero, ErrInvalidDigest
}
return d, nil
}
func DigestFromBytes[S ~[]byte | ~string](v S) Digest {
return Digest{sha256.Sum256([]byte(v))}
}
// String returns the string representation of the digest in the conventional
// form "sha256:<hex>".
func (d Digest) String() string {
return fmt.Sprintf("sha256:%x", d.sum[:])
}
func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4])
}
func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:])
}
// IsValid returns true if the digest is valid, i.e. if it is the SHA-256 hash
// of some content.
func (d Digest) IsValid() bool {
return d != (Digest{})
}
// MarshalText implements the encoding.TextMarshaler interface. It returns an
// error if [Digest.IsValid] returns false.
func (d Digest) MarshalText() ([]byte, error) {
return []byte(d.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface, and only
// works for a zero digest. If [Digest.IsValid] returns true, it returns an
// error.
func (d *Digest) UnmarshalText(text []byte) error {
if *d != (Digest{}) {
return errors.New("digest: illegal UnmarshalText on valid digest")
}
v, err := ParseDigest(string(text))
if err != nil {
return err
}
*d = v
return nil
}
package blob
import (
"encoding/json"
"testing"
)
func TestParseDigest(t *testing.T) {
cases := []struct {
in string
valid bool
}{
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
// too short
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
// too long
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
// invalid prefix
{"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
{"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
{"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
// invalid hex
{"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
{"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
}
for _, tt := range cases {
got, err := ParseDigest(tt.in)
if tt.valid && err != nil {
t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err)
}
want := "sha256:" + tt.in[7:]
if tt.valid && got.String() != want {
t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want)
}
}
}
func TestDigestMarshalText(t *testing.T) {
const s = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
var d Digest
if err := json.Unmarshal([]byte(s), &d); err != nil {
t.Errorf("json.Unmarshal: %v", err)
}
out, err := json.Marshal(d)
if err != nil {
t.Errorf("json.Marshal: %v", err)
}
want := `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
if string(out) != want {
t.Errorf("json.Marshal: got %s, want %s", out, want)
}
if err := json.Unmarshal([]byte(`"invalid"`), &Digest{}); err == nil {
t.Errorf("json.Unmarshal: expected error")
}
}
package chunks
import (
"fmt"
"iter"
"strconv"
"strings"
)
type Chunk struct {
Start, End int64
}
func New(start, end int64) Chunk {
return Chunk{start, end}
}
// ParseRange parses a string in the form "unit=range" where unit is a string
// and range is a string in the form "start-end". It returns the unit and the
// range as a Chunk.
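//
// For example, ParseRange("bytes=0-5") returns ("bytes", Chunk{0, 5}, nil).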
func ParseRange(s string) (unit string, _ Chunk, _ error) {
unit, r, _ := strings.Cut(s, "=")
if r == "" {
return unit, Chunk{}, nil
}
c, err := Parse(r)
if err != nil {
return "", Chunk{}, err
}
return unit, c, err
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
}
end, err := strconv.ParseInt(endStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
}
return Chunk{start, end}, nil
}
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
// the range [0, size), in order.
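//
// A short sketch: Of(10, 4) covers ten bytes in inclusive ranges that match
// HTTP Range semantics.
//
//	for c := range Of(10, 4) {
//		fmt.Println(c) // prints 0-3, then 4-7, then 8-9
//	}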
func Of(size, chunkSize int64) iter.Seq[Chunk] {
return func(yield func(Chunk) bool) {
for start := int64(0); start < size; start += chunkSize {
end := min(start+chunkSize-1, size-1)
if !yield(Chunk{start, end}) {
break
}
}
}
}
// Count returns the number of Chunks of size chunkSize needed to cover the
// range [0, size).
func Count(size, chunkSize int64) int64 {
return (size + chunkSize - 1) / chunkSize
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// String returns the string representation of the Chunk in the form
// "{start}-{end}".
func (c Chunk) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}
package chunks
import (
"slices"
"testing"
)
func TestOf(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want []Chunk
}{
{0, 1, nil},
{1, 1, []Chunk{{0, 0}}},
{1, 2, []Chunk{{0, 0}}},
{2, 1, []Chunk{{0, 0}, {1, 1}}},
{10, 9, []Chunk{{0, 8}, {9, 9}}},
}
for _, tt := range cases {
got := slices.Collect(Of(tt.total, tt.chunkSize))
if !slices.Equal(got, tt.want) {
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
}
}
}
func TestSize(t *testing.T) {
cases := []struct {
c Chunk
want int64
}{
{Chunk{0, 0}, 1},
{Chunk{0, 1}, 2},
{Chunk{3, 4}, 2},
}
for _, tt := range cases {
got := tt.c.Size()
if got != tt.want {
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
}
}
}
func TestCount(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want int64
}{
{0, 1, 0},
{1, 1, 1},
{1, 2, 1},
{2, 1, 2},
{10, 9, 2},
}
for _, tt := range cases {
got := Count(tt.total, tt.chunkSize)
if got != tt.want {
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
}
}
}
package ollama
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"math/rand/v2"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"slices"
"strings"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/testutil"
)
func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object.
var m Manifest
data, err := json.Marshal(m)
if err != nil {
t.Fatal(err)
}
if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
t.Error("expected manifest to contain empty config")
t.Fatalf("got:\n%s", string(data))
}
}
func link(c *blob.DiskCache, name string, manifest string) {
_, n, _, err := parseName(name)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
var errRoundTrip = errors.New("forced roundtrip error")
type recordRoundTripper http.HandlerFunc
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
w := httptest.NewRecorder()
rr(w, req)
if w.Code == 499 {
return nil, errRoundTrip
}
resp := w.Result()
// For some reason, Response.Request is not set by httptest.NewRecorder, so we
// set it manually.
resp.Request = req
return resp, nil
}
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
mklayer := func(data string) *Layer {
return &Layer{
Digest: importBytes(t, c, data),
Size: int64(len(data)),
}
}
commit := func(name string, layers ...*Layer) {
t.Helper()
data, err := json.Marshal(&Manifest{Layers: layers})
if err != nil {
t.Fatal(err)
}
link(c, name, string(data))
}
link(c, "empty", "")
commit("zero")
commit("single", mklayer("exists"))
commit("multiple", mklayer("exists"), mklayer("present"))
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
commit("null", nil)
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link(c, "invalid", "!!!!!")
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
return rc, c
}
func okHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func checkErrCode(t *testing.T, err error, status int, code string) {
t.Helper()
var e *Error
if !errors.As(err, &e) || e.Status != status || e.Code != code {
t.Errorf("err = %v; want %v %v", err, status, code)
}
}
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
d, err := c.Import(strings.NewReader(data), int64(len(data)))
if err != nil {
t.Fatal(err)
}
return d
}
func TestRegistryPushInvalidNames(t *testing.T) {
rc, c := newClient(t, nil)
cases := []struct {
name string
err error
}{
{"", ErrNameInvalid},
{"@", ErrNameInvalid},
{"@x", blob.ErrInvalidDigest},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
// Create a new registry and push a new image.
err := rc.Push(t.Context(), c, tt.name, nil)
if !errors.Is(err, tt.err) {
t.Errorf("err = %v; want %v", err, tt.err)
}
})
}
}
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
return WithTrace(ctx, t), t
}
func TestPushZero(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "empty", nil)
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
}
}
func TestPushSingle(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "single", nil)
testutil.Check(t, err)
}
func TestPushMultiple(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "multiple", nil)
testutil.Check(t, err)
}
func TestPushNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r)
})
err := rc.Push(t.Context(), c, "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
}
func TestPushNullLayer(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushSizeMismatch(t *testing.T) {
rc, c := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context())
got := rc.Push(ctx, c, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got)
}
}
func TestPushInvalid(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushExistsAtRemote(t *testing.T) {
var pushed bool
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed {
// First push. Return an uploadURL.
pushed = true
w.Header().Set("Location", "http://blob.store/blobs/123")
return
}
w.WriteHeader(http.StatusAccepted)
return
}
io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
})
rc.MaxStreams = 1 // prevent concurrent uploads
var errs []error
ctx := WithTrace(t.Context(), &Trace{
Update: func(_ *Layer, n int64, err error) {
// uploading one at a time so no need to lock
errs = append(errs, err)
},
})
check := testutil.Checker(t)
err := rc.Push(ctx, c, "single", nil)
check(err)
if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
}
err = rc.Push(ctx, c, "single", nil)
check(err)
}
func TestPushRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return
}
})
got := rc.Push(t.Context(), c, "single", nil)
checkErrCode(t, got, 500, "blob_error")
}
func TestPushLocationError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted)
})
got := rc.Push(t.Context(), c, "single", nil)
wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains)
}
}
func TestPushUploadRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload
return
}
w.Header().Set("Location", "http://blob.store/blobs/123")
})
got := rc.Push(t.Context(), c, "single", nil)
if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip)
}
}
func TestPushUploadFileOpenError(t *testing.T) {
rc, c := newClient(t, okHandler)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, _ int64, err error) {
// Remove the file just before it is opened for upload,
// but after the initial Stat that happens before the
// upload starts
os.Remove(c.GetFile(l.Digest))
},
})
got := rc.Push(ctx, c, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}
func TestPushCommitRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected")
}
w.WriteHeader(499) // force RoundTrip error
})
err := rc.Push(t.Context(), c, "zero", nil)
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
func checkNotExist(t *testing.T, err error) {
t.Helper()
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v; want fs.ErrNotExist", err)
}
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Pull(t.Context(), c, "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
}
}
func TestRegistryPullInvalidManifest(t *testing.T) {
cases := []string{
"",
"null",
"!!!",
`{"layers":[]}`,
}
for _, resp := range cases {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), c, "x")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
}
}
func TestRegistryPullNotCached(t *testing.T) {
check := testutil.Checker(t)
var c *blob.DiskCache
var rc *Registry
d := blob.DigestFromBytes("some data")
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
io.WriteString(w, "some data")
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
})
// Confirm that the layer does not exist locally
_, err := ResolveLocal(c, "model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), c, "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := ResolveLocal(c, "model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
// Confirm successful download
info, err := c.Get(d)
check(err)
if info.Digest != d {
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
}
if info.Size != 9 {
t.Errorf("info.Size = %v; want %v", info.Size, 9)
}
data, err := os.ReadFile(c.GetFile(d))
check(err)
if string(data) != "some data" {
t.Errorf("data = %q; want %q", data, "exists")
}
}
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
}
if strings.Contains(r.URL.Path, "/manifests/") {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
}
})
var errs []error
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
t.Logf("update %v %d %v", d, n, err)
reads = append(reads, n)
errs = append(errs, err)
},
})
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, c, "single")
testutil.Check(t, err)
want := []int64{6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
if !slices.Equal(reads, want) {
t.Errorf("pairs = %v; want %v", reads, want)
}
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), c, "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), c, "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), c, "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
// interfere with pulling layers that are not cached
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
x := blob.DigestFromBytes("xxxxxx")
e := blob.DigestFromBytes("exists")
y := blob.DigestFromBytes("yyyyyy")
for i := range 10 {
t.Logf("iteration %d", i)
digests := []blob.Digest{x, e, y}
rand.Shuffle(len(digests), func(i, j int) {
digests[i], digests[j] = digests[j], digests[i]
})
manifest := fmt.Sprintf(`{
"layers": [
{"digest":"%s","size":6},
{"digest":"%s","size":6},
{"digest":"%s","size":6}
]
}`, digests[0], digests[1], digests[2])
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch path.Base(r.URL.Path) {
case "latest":
io.WriteString(w, manifest)
case x.String():
io.WriteString(w, "xxxxxx")
case e.String():
io.WriteString(w, "exists")
case y.String():
io.WriteString(w, "yyyyyy")
default:
panic(fmt.Sprintf("unexpected request: %v", r))
}
})
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("update %v %d %v", l, n, err)
},
})
// Check that we pull all layers that we can.
err := rc.Pull(ctx, c, "mixed")
if err != nil {
t.Fatal(err)
}
for _, d := range digests {
info, err := c.Get(d)
if err != nil {
t.Fatalf("Get(%v): %v", d, err)
}
if info.Size != 6 {
t.Errorf("info.Size = %v; want %v", info.Size, 6)
}
}
}
}
func TestRegistryPullChunking(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})
// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})
err := rc.Pull(ctx, c, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
exists := blob.DigestFromBytes("exists")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
w.WriteHeader(499) // should not hit manifest endpoint
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
})
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
check(err)
}
func TestInsecureSkipVerify(t *testing.T) {
exists := blob.DigestFromBytes("exists")
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
}))
defer s.Close()
const name = "ollama.com/library/insecure"
var rc Registry
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
_, err := rc.Resolve(t.Context(), url)
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
t.Errorf("err = %v; want cert verifiction failure", err)
}
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
_, err = rc.Resolve(t.Context(), url)
testutil.Check(t, err)
}
func TestCanRetry(t *testing.T) {
cases := []struct {
err error
want bool
}{
{nil, false},
{errors.New("x"), false},
{ErrCached, false},
{ErrManifestInvalid, false},
{ErrNameInvalid, false},
{&Error{Status: 100}, false},
{&Error{Status: 500}, true},
}
for _, tt := range cases {
if got := canRetry(tt.err); got != tt.want {
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
}
}
}
package ollama
import (
"context"
)
// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
//
// It is called once at the beginning of the download with a zero n and
// then once per read operation with the number of bytes read so far,
// and an error if any.
//
// Any function assigned must be safe for concurrent use. The function
// is called synchronously and so should not block or take long to run.
Update func(_ *Layer, n int64, _ error)
}
func (t *Trace) update(l *Layer, n int64, err error) {
if t.Update != nil {
t.Update(l, n, err)
}
}
type traceKey struct{}
// WithTrace returns a context derived from ctx that uses t to report trace
// events.
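//
// A sketch of typical use (the logging body is illustrative only; Layer's
// Digest and Size fields are assumed from their use elsewhere in the tests):
//
//	ctx = WithTrace(ctx, &Trace{
//		Update: func(l *Layer, n int64, err error) {
//			log.Printf("%s: %d/%d bytes (err=%v)", l.Digest.Short(), n, l.Size, err)
//		},
//	})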
func WithTrace(ctx context.Context, t *Trace) context.Context {
return context.WithValue(ctx, traceKey{}, t)
}
var emptyTrace = &Trace{}
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
// none is found.
//
// It never returns nil.
func traceFromContext(ctx context.Context) *Trace {
t, _ := ctx.Value(traceKey{}).(*Trace)
if t == nil {
return emptyTrace
}
return t
}
// safetensors provides a reader for the safetensor directories and files.
package safetensors
import (
"encoding/json"
"fmt"
"io"
"io/fs"
"iter"
"slices"
"strconv"
"strings"
)
// Tensor represents a single tensor in a safetensors file.
//
// Its zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
name string
dataType string
shape []int64
fsys fs.FS
fname string // entry name in fsys
offset int64
size int64
}
type Model struct {
fsys fs.FS
}
func Read(fsys fs.FS) (*Model, error) {
return &Model{fsys: fsys}, nil
}
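// Tensors returns an iterator over the tensors in all *.safetensors files
// under the model's file system. A usage sketch (dir is hypothetical):
//
//	m, _ := safetensors.Read(os.DirFS(dir))
//	for t, err := range m.Tensors() {
//		if err != nil { /* handle */ }
//		fmt.Println(t.Name(), t.DataType(), t.Shape(), t.Size())
//	}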
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
return func(yield func(*Tensor, error) bool) {
entries, err := fs.Glob(m.fsys, "*.safetensors")
if err != nil {
yield(nil, err)
return
}
for _, e := range entries {
tt, err := m.readTensors(e)
if err != nil {
yield(nil, err)
return
}
for _, t := range tt {
if !yield(t, nil) {
return
}
}
}
}
}
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
f, err := m.fsys.Open(fname)
if err != nil {
return nil, err
}
defer f.Close()
finfo, err := f.Stat()
if err != nil {
return nil, err
}
headerSize, err := readInt64(f)
if err != nil {
return nil, err
}
data := make([]byte, headerSize)
_, err = io.ReadFull(f, data)
if err != nil {
return nil, err
}
var raws map[string]json.RawMessage
if err := json.Unmarshal(data, &raws); err != nil {
return nil, err
}
	// TODO(bmizerany): do something with metadata? This could be another
	// header read if needed. We also need to figure out if the metadata is
	// present in only one .safetensors file or if each file may have its
	// own, and whether it needs to follow each tensor. Currently, I
	// (bmizerany) am only seeing it show up with one entry for file type,
	// which is always "pt".
tt := make([]*Tensor, 0, len(raws))
for name, raw := range raws {
if !strings.HasPrefix(name, "model.layer") {
continue
}
var v struct {
DataType string `json:"dtype"`
Shape []int64 `json:"shape"`
Offsets []int64 `json:"data_offsets"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
}
if len(v.Offsets) != 2 {
return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
}
		// TODO(bmizerany): after collecting, validate all offsets make
		// tensors contiguous?
begin, end := v.Offsets[0], v.Offsets[1]
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
return nil, err
}
// TODO(bmizerany): just yield.. don't be silly and make a slice :)
tt = append(tt, &Tensor{
name: name,
dataType: v.DataType,
shape: v.Shape,
fsys: m.fsys,
fname: fname,
offset: begin,
size: end - begin,
})
}
return tt, nil
}
func checkBeginEnd(size, begin, end int64) error {
if begin < 0 {
return fmt.Errorf("begin must not be negative: %d", begin)
}
if end < 0 {
return fmt.Errorf("end must not be negative: %d", end)
}
if end < begin {
return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
}
if end > size {
return fmt.Errorf("end must be <= size: %d > %d", end, size)
}
return nil
}
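// readInt64 reads an 8-byte little-endian integer from r.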
func readInt64(r io.Reader) (int64, error) {
var v uint64
var buf [8]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
return 0, err
}
for i := range buf {
v |= uint64(buf[i]) << (8 * i)
}
return int64(v), nil
}
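// Shape is the list of dimensions of a tensor. String renders it in the
// compact "[d0,d1,...]" form.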
type Shape []int64
func (s Shape) String() string {
var b strings.Builder
b.WriteByte('[')
for i, v := range s {
if i > 0 {
b.WriteByte(',')
}
b.WriteString(strconv.FormatInt(v, 10))
}
b.WriteByte(']')
return b.String()
}
func (t *Tensor) Name() string { return t.name }
func (t *Tensor) DataType() string { return t.dataType }
func (t *Tensor) Size() int64 { return t.size }
func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
func (t *Tensor) Reader() (io.ReadCloser, error) {
f, err := t.fsys.Open(t.fname)
if err != nil {
return nil, err
}
r := newSectionReader(f, t.offset, t.size)
rc := struct {
io.Reader
io.Closer
}{r, f}
return rc, nil
}
// newSectionReader returns an io.Reader that reads n bytes from r starting
// at offset. It is a convenience function for creating an io.SectionReader
// when r may not be an io.ReaderAt.
//
// If r is an io.ReaderAt, it is wrapped in an io.SectionReader directly.
// Otherwise, if r is an io.ReadSeeker, it is seeked to offset and limited to
// n bytes. Failing both, the slow path is used: the first offset bytes are
// read and discarded before returning a limited reader.
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
	if r, ok := r.(io.ReaderAt); ok {
		return io.NewSectionReader(r, offset, n)
	}
	if rs, ok := r.(io.ReadSeeker); ok {
		if _, err := rs.Seek(offset, io.SeekStart); err != nil {
			return errorReader{err}
		}
		return io.LimitReader(rs, n)
	}
	// Slow path: discard to offset and return a limited reader.
	if _, err := io.CopyN(io.Discard, r, offset); err != nil {
		return errorReader{err}
	}
	return io.LimitReader(r, n)
}
// errorReader is an io.Reader that always returns its error. It defers
// setup failures in newSectionReader to the first Read call.
type errorReader struct{ err error }
func (e errorReader) Read([]byte) (int, error) { return 0, e.err }
package main
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
"golang.org/x/sync/errgroup"
)
var stdout io.Writer = os.Stdout
const usage = `Opp is a tool for pushing and pulling Ollama models.
Usage:
opp [flags] <push|pull|import>
Commands:
push Upload a model to the Ollama server.
pull Download a model from the Ollama server.
import Import a model from a local safetensor directory.
Examples:
# Pull a model from the Ollama server.
opp pull library/llama3.2:latest
# Push a model to the Ollama server.
opp push username/my_model:8b
# Import a model from a local safetensor directory.
opp import /path/to/safetensor
Environment Variables:
OLLAMA_MODELS
The directory where models are pushed and pulled from
(default ~/.ollama/models).
`
func main() {
flag.Usage = func() {
fmt.Fprint(os.Stderr, usage)
}
flag.Parse()
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
rc, err := ollama.RegistryFromEnv()
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
err = func() error {
switch cmd := flag.Arg(0); cmd {
case "pull":
return cmdPull(ctx, rc, c)
case "push":
return cmdPush(ctx, rc, c)
case "import":
return cmdImport(ctx, c)
default:
if cmd == "" {
flag.Usage()
} else {
fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
}
os.Exit(2)
return errors.New("unreachable")
}
}()
if err != nil {
fmt.Fprintf(os.Stderr, "opp: %v\n", err)
os.Exit(1)
}
}
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
model := flag.Arg(1)
if model == "" {
flag.Usage()
os.Exit(1)
}
tr := http.DefaultTransport.(*http.Transport).Clone()
// TODO(bmizerany): configure transport?
rc.HTTPClient = &http.Client{Transport: tr}
var mu sync.Mutex
p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
var pb bytes.Buffer
	// Write progress to a buffer first to avoid blocking on stdout while
	// holding the lock.
	printProgress := func() {
		pb.Reset()
		mu.Lock()
		for d, s := range p {
			stamp := time.Now().Format("2006/01/02 15:04:05")
			fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
			if s[0] == s[1] {
				delete(p, d)
			}
		}
		mu.Unlock()
		io.Copy(stdout, &pb)
	}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if err != nil && !errors.Is(err, ollama.ErrCached) {
fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
return
}
mu.Lock()
p[l.Digest] = [2]int64{l.Size, n}
mu.Unlock()
},
})
errc := make(chan error)
go func() {
errc <- rc.Pull(ctx, c, model)
}()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-t.C:
printProgress()
case err := <-errc:
printProgress()
return err
}
}
}
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
flag.PrintDefaults()
}
flag.Parse(args)
model := flag.Arg(0)
if model == "" {
return fmt.Errorf("missing model argument")
}
from := cmp.Or(*flagFrom, model)
m, err := ollama.ResolveLocal(c, from)
if err != nil {
return err
}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
switch {
case errors.Is(err, ollama.ErrCached):
fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
case err != nil:
fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
case n == 0:
l := m.Layer(l.Digest)
mt, p, _ := mime.ParseMediaType(l.MediaType)
mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
switch mt {
case "tensor":
fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
default:
fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
}
}
},
})
return rc.Push(ctx, c, model, &ollama.PushParams{
From: from,
})
}
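// trackingReader wraps an io.Reader, atomically adding the number of bytes
// read to n so a progress loop in another goroutine can observe it.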
type trackingReader struct {
io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.Reader.Read(p)
r.n.Add(int64(n))
return n, err
}
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("import", flag.ExitOnError)
flagAs := flag.String("as", "", "Import using the provided name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
flag.PrintDefaults()
}
flag.Parse(args)
dir := cmp.Or(flag.Arg(0), ".")
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
m, err := safetensors.Read(os.DirFS(dir))
if err != nil {
return err
}
var total int64
var tt []*safetensors.Tensor
for t, err := range m.Tensors() {
if err != nil {
return err
}
tt = append(tt, t)
total += t.Size()
}
var n atomic.Int64
done := make(chan error)
go func() {
layers := make([]*ollama.Layer, len(tt))
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
var ctxErr error
for i, t := range tt {
if ctx.Err() != nil {
			// The context may be canceled AFTER we exit the
			// loop, so consulting ctx.Err() after the loop
			// could report cancellation as the error that
			// broke the loop when it was not, leading the
			// user to think their import failed when it did
			// not. Capture the error here, if and only if we
			// exit the loop because of it, and report that.
ctxErr = ctx.Err()
break
}
g.Go(func() (err error) {
rc, err := t.Reader()
if err != nil {
return err
}
defer rc.Close()
tr := &trackingReader{rc, &n}
d, err := c.Import(tr, t.Size())
if err != nil {
return err
}
if err := rc.Close(); err != nil {
return err
}
layers[i] = &ollama.Layer{
Digest: d,
Size: t.Size(),
MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
"name": t.Name(),
"dtype": t.DataType(),
"shape": t.Shape().String(),
}),
}
return nil
})
}
done <- func() error {
if err := errors.Join(g.Wait(), ctxErr); err != nil {
return err
}
m := &ollama.Manifest{Layers: layers}
data, err := json.MarshalIndent(m, "", " ")
if err != nil {
return err
}
d := blob.DigestFromBytes(data)
err = blob.PutBytes(c, d, data)
if err != nil {
return err
}
return c.Link(*flagAs, d)
}()
}()
fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
csiHideCursor(stdout)
defer csiShowCursor(stdout)
csiSavePos(stdout)
writeProgress := func() {
csiRestorePos(stdout)
nn := n.Load()
fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
formatNatural(nn),
formatNatural(total),
nn*100/total,
ansiClearToEndOfLine,
)
}
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
writeProgress()
case err := <-done:
writeProgress()
return err
}
}
}
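// formatNatural formats n as a human-readable byte count using 1024-based
// units (B, KB, MB, GB).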
func formatNatural(n int64) string {
switch {
case n < 1024:
return fmt.Sprintf("%d B", n)
case n < 1024*1024:
return fmt.Sprintf("%.1f KB", float64(n)/1024)
case n < 1024*1024*1024:
return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
default:
return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
}
}
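// ANSI/CSI escape sequences used by cmdImport to redraw the progress line
// in place.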
const ansiClearToEndOfLine = "\033[K"
func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
os.Exit(1)
}
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)
func BenchmarkDownload(b *testing.B) {
run := func(fileSize, chunkSize int64) {
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
}
run(100<<20, 8<<20)
run(100<<20, 16<<20)
run(100<<20, 32<<20)
run(100<<20, 64<<20)
run(100<<20, 128<<20) // 1 chunk
}
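// run downloads the given chunk of the benchmark blob using an HTTP Range
// request and discards the body.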
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
	_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // returns io.EOF on short read
return err
}
var sleepTime atomic.Int64
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
client := &http.Client{
Transport: func() http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true
return tr
}(),
}
defer client.CloseIdleConnections()
// warm up the client
run(context.Background(), client, chunks.New(0, 1<<20))
b.SetBytes(fileSize)
b.ReportAllocs()
	// Give our CDN a moment to breathe between benchmarks.
	time.Sleep(time.Duration(sleepTime.Swap(int64(3 * time.Second))))
for b.Loop() {
g, ctx := errgroup.WithContext(b.Context())
g.SetLimit(runtime.GOMAXPROCS(0))
for chunk := range chunks.Of(fileSize, chunkSize) {
g.Go(func() error { return run(ctx, client, chunk) })
}
if err := g.Wait(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
}
func benchmarkWrite(b *testing.B, chunkSize int) {
b.ReportAllocs()
dir := b.TempDir()
f, err := os.Create(filepath.Join(dir, "write-single"))
if err != nil {
b.Fatal(err)
}
defer f.Close()
data := make([]byte, chunkSize)
b.SetBytes(int64(chunkSize))
r := bytes.NewReader(data)
for b.Loop() {
r.Reset(data)
_, err := io.Copy(f, r)
if err != nil {
b.Fatal(err)
}
}
}
package backoff
import (
"context"
"iter"
"math/rand/v2"
"time"
)
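// Loop yields the attempt number n (starting at 0) and any context error.
// Iteration ends after yielding a non-nil error once ctx is done, or when
// the caller breaks; between yields, Loop sleeps n^2 * 10ms, capped at
// maxBackoff and jittered by a random factor in [0.5, 1.5).
//
// A typical retry loop looks like this (a sketch; try is hypothetical):
//
//	for _, err := range backoff.Loop(ctx, 3*time.Second) {
//		if err != nil {
//			return err // ctx canceled or deadline exceeded
//		}
//		if err := try(); err == nil {
//			return nil
//		}
//	}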
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
var t *time.Timer
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
if !yield(n, nil) {
return
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := time.Duration(n*n) * 10 * time.Millisecond
if d > maxBackoff {
d = maxBackoff
}
			// Randomize the delay between 0.5x and 1.5x its value
			// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
}
}
}
}
//go:build goexperiment.synctest
package backoff
import (
"context"
"errors"
"testing"
"testing/synctest"
"time"
)
func TestLoop(t *testing.T) {
synctest.Run(func() {
last := -1
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
for n, err := range Loop(ctx, 100*time.Millisecond) {
if !errors.Is(err, ctx.Err()) {
t.Errorf("err = %v, want nil", err)
}
if err != nil {
break
}
if n != last+1 {
t.Errorf("n = %d, want %d", n, last+1)
}
last = n
if n > 5 {
cancel()
}
}
if last != 6 {
t.Errorf("last = %d, want 6", last)
}
})
}
package backoff
import (
"context"
"testing"
"testing/synctest"
"time"
)
func TestLoopAllocs(t *testing.T) {
for i := range 3 {
got := testing.AllocsPerRun(1000, func() {
for tick := range Loop(t.Context(), 1) {
if tick >= i {
break
}
}
})
want := float64(0)
if i > 0 {
want = 3 // due to time.NewTimer
}
if got > want {
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
}
}
}
func BenchmarkLoop(b *testing.B) {
ctx := context.Background()
synctest.Run(func() {
for n := range Loop(ctx, 100*time.Millisecond) {
if n == b.N {
break
}
}
})
}