"vscode:/vscode.git/clone" did not exist on "19a2f8e4ffcf0340ba9cedd1ea30356b3510ab18"
Commit f397e0e9 authored by Daniel Hiltgen's avatar Daniel Hiltgen Committed by jmorganca
Browse files

Move hub auth out to new package

parent 9da9e8fb
package server
package auth
import (
"bytes"
......@@ -24,6 +24,10 @@ import (
"github.com/jmorganca/ollama/api"
)
const (
KeyType = "id_ed25519"
)
type AuthRedirect struct {
Realm string
Service string
......@@ -71,39 +75,47 @@ func (r AuthRedirect) URL() (*url.URL, error) {
return redirectURL, nil
}
func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
redirectURL, err := redirData.URL()
if err != nil {
return "", err
}
func SignRequest(method, url string, data []byte, headers http.Header) error {
home, err := os.UserHomeDir()
if err != nil {
return "", err
return err
}
keyPath := filepath.Join(home, ".ollama", "id_ed25519")
keyPath := filepath.Join(home, ".ollama", KeyType)
rawKey, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
return err
}
s := SignatureData{
Method: http.MethodGet,
Path: redirectURL.String(),
Data: nil,
Method: method,
Path: url,
Data: data,
}
sig, err := s.Sign(rawKey)
if err != nil {
return err
}
headers.Set("Authorization", sig)
return nil
}
func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
redirectURL, err := redirData.URL()
if err != nil {
return "", err
}
headers := make(http.Header)
headers.Set("Authorization", sig)
resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers)
if err != nil {
return "", err
}
resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get token: %q", err))
return "", err
......
package auth
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"runtime"
"strconv"
"github.com/jmorganca/ollama/version"
)
type RegistryOptions struct {
Insecure bool
Username string
Password string
Token string
}
func MakeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts != nil {
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if s := req.Header.Get("Content-Length"); s != "" {
contentLength, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
}
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
......@@ -22,6 +22,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/format"
)
......@@ -85,7 +86,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
return n, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil {
return err
......@@ -137,11 +138,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
return nil
}
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) {
b.err = b.run(ctx, requestURL, opts)
}
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
......@@ -210,7 +211,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
return nil
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
headers := make(http.Header)
......@@ -334,7 +335,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
type downloadOpts struct {
mp ModelPath
digest string
regOpts *RegistryOptions
regOpts *auth.RegistryOptions
fn func(api.ProgressResponse)
}
......
......@@ -16,25 +16,17 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"text/template"
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
)
type RegistryOptions struct {
Insecure bool
Username string
Password string
Token string
}
type Model struct {
Name string `json:"name"`
Config ConfigV2
......@@ -320,7 +312,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
return err
}
......@@ -840,7 +832,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
return buf.String(), nil
}
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
......@@ -890,7 +882,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil
}
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2
......@@ -996,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
......@@ -1028,9 +1020,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
var errUnauthorized = fmt.Errorf("unauthorized")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
if !errors.Is(err, context.Canceled) {
slog.Info(fmt.Sprintf("request failed: %v", err))
......@@ -1042,9 +1034,9 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
switch {
case resp.StatusCode == http.StatusUnauthorized:
// Handle authentication error with one retry
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
authenticate := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(authenticate)
token, err := auth.GetAuthToken(ctx, authRedir)
if err != nil {
return nil, err
}
......@@ -1071,58 +1063,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
return nil, errUnauthorized
}
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts != nil {
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if s := req.Header.Get("Content-Length"); s != "" {
contentLength, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
}
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func getValue(header, key string) string {
startIdx := strings.Index(header, key+"=")
if startIdx == -1 {
......@@ -1146,10 +1086,10 @@ func getValue(header, key string) string {
return header[startIdx:endIdx]
}
func ParseAuthRedirectString(authStr string) AuthRedirect {
func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
authStr = strings.TrimPrefix(authStr, "Bearer ")
return AuthRedirect{
return auth.AuthRedirect{
Realm: getValue(authStr, "realm"),
Service: getValue(authStr, "service"),
Scope: getValue(authStr, "scope"),
......
......@@ -25,6 +25,7 @@ import (
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/gpu"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/openai"
......@@ -479,7 +480,7 @@ func PullModelHandler(c *gin.Context) {
ch <- r
}
regOpts := &RegistryOptions{
regOpts := &auth.RegistryOptions{
Insecure: req.Insecure,
}
......@@ -528,7 +529,7 @@ func PushModelHandler(c *gin.Context) {
ch <- r
}
regOpts := &RegistryOptions{
regOpts := &auth.RegistryOptions{
Insecure: req.Insecure,
}
......
......@@ -18,6 +18,7 @@ import (
"time"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/format"
"golang.org/x/sync/errgroup"
)
......@@ -49,7 +50,7 @@ const (
maxUploadPartSize int64 = 1000 * format.MegaByte
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
p, err := GetBlobsPath(b.Digest)
if err != nil {
return err
......@@ -121,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
......@@ -212,7 +213,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
b.done = true
}
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
......@@ -227,7 +228,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil {
w.Rollback()
return err
......@@ -277,9 +278,9 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
authenticate := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(authenticate)
token, err := auth.GetAuthToken(ctx, authRedir)
if err != nil {
return err
}
......@@ -364,7 +365,7 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment