Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
0a892419
Unverified
Commit
0a892419
authored
Aug 21, 2023
by
Ryan Baker
Committed by
GitHub
Aug 21, 2023
Browse files
Strip protocol from model path (#377)
parent
e3054fc7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
231 additions
and
43 deletions
+231
-43
cmd/cmd.go
cmd/cmd.go
+20
-6
server/images.go
server/images.go
+44
-10
server/modelpath.go
server/modelpath.go
+39
-26
server/modelpath_test.go
server/modelpath_test.go
+122
-0
server/routes.go
server/routes.go
+6
-1
No files found.
cmd/cmd.go
View file @
0a892419
...
@@ -97,7 +97,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
...
@@ -97,7 +97,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
}
func
RunHandler
(
cmd
*
cobra
.
Command
,
args
[]
string
)
error
{
func
RunHandler
(
cmd
*
cobra
.
Command
,
args
[]
string
)
error
{
mp
:=
server
.
ParseModelPath
(
args
[
0
])
insecure
,
err
:=
cmd
.
Flags
()
.
GetBool
(
"insecure"
)
if
err
!=
nil
{
return
err
}
mp
,
err
:=
server
.
ParseModelPath
(
args
[
0
],
insecure
)
if
err
!=
nil
{
return
err
}
fp
,
err
:=
mp
.
GetManifestPath
(
false
)
fp
,
err
:=
mp
.
GetManifestPath
(
false
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
...
@@ -106,7 +115,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
...
@@ -106,7 +115,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
_
,
err
=
os
.
Stat
(
fp
)
_
,
err
=
os
.
Stat
(
fp
)
switch
{
switch
{
case
errors
.
Is
(
err
,
os
.
ErrNotExist
)
:
case
errors
.
Is
(
err
,
os
.
ErrNotExist
)
:
if
err
:=
pull
(
args
[
0
],
fals
e
);
err
!=
nil
{
if
err
:=
pull
(
args
[
0
],
insecur
e
);
err
!=
nil
{
var
apiStatusError
api
.
StatusError
var
apiStatusError
api
.
StatusError
if
!
errors
.
As
(
err
,
&
apiStatusError
)
{
if
!
errors
.
As
(
err
,
&
apiStatusError
)
{
return
err
return
err
...
@@ -506,7 +515,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
...
@@ -506,7 +515,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
case
strings
.
HasPrefix
(
line
,
"/show"
)
:
case
strings
.
HasPrefix
(
line
,
"/show"
)
:
args
:=
strings
.
Fields
(
line
)
args
:=
strings
.
Fields
(
line
)
if
len
(
args
)
>
1
{
if
len
(
args
)
>
1
{
mp
:=
server
.
ParseModelPath
(
model
)
mp
,
err
:=
server
.
ParseModelPath
(
model
,
false
)
if
err
!=
nil
{
return
err
}
manifest
,
err
:=
server
.
GetManifest
(
mp
)
manifest
,
err
:=
server
.
GetManifest
(
mp
)
if
err
!=
nil
{
if
err
!=
nil
{
fmt
.
Println
(
"error: couldn't get a manifest for this model"
)
fmt
.
Println
(
"error: couldn't get a manifest for this model"
)
...
@@ -569,7 +582,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
...
@@ -569,7 +582,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
}
}
func
RunServer
(
cmd
*
cobra
.
Command
,
_
[]
string
)
error
{
func
RunServer
(
cmd
*
cobra
.
Command
,
_
[]
string
)
error
{
var
host
,
port
=
"127.0.0.1"
,
"11434"
host
,
port
:
=
"127.0.0.1"
,
"11434"
parts
:=
strings
.
Split
(
os
.
Getenv
(
"OLLAMA_HOST"
),
":"
)
parts
:=
strings
.
Split
(
os
.
Getenv
(
"OLLAMA_HOST"
),
":"
)
if
ip
:=
net
.
ParseIP
(
parts
[
0
]);
ip
!=
nil
{
if
ip
:=
net
.
ParseIP
(
parts
[
0
]);
ip
!=
nil
{
...
@@ -630,7 +643,7 @@ func initializeKeypair() error {
...
@@ -630,7 +643,7 @@ func initializeKeypair() error {
return
fmt
.
Errorf
(
"could not create directory %w"
,
err
)
return
fmt
.
Errorf
(
"could not create directory %w"
,
err
)
}
}
err
=
os
.
WriteFile
(
privKeyPath
,
pem
.
EncodeToMemory
(
privKeyBytes
),
0600
)
err
=
os
.
WriteFile
(
privKeyPath
,
pem
.
EncodeToMemory
(
privKeyBytes
),
0
o
600
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
}
}
...
@@ -642,7 +655,7 @@ func initializeKeypair() error {
...
@@ -642,7 +655,7 @@ func initializeKeypair() error {
pubKeyData
:=
ssh
.
MarshalAuthorizedKey
(
sshPrivateKey
.
PublicKey
())
pubKeyData
:=
ssh
.
MarshalAuthorizedKey
(
sshPrivateKey
.
PublicKey
())
err
=
os
.
WriteFile
(
pubKeyPath
,
pubKeyData
,
0644
)
err
=
os
.
WriteFile
(
pubKeyPath
,
pubKeyData
,
0
o
644
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
}
}
...
@@ -737,6 +750,7 @@ func NewCLI() *cobra.Command {
...
@@ -737,6 +750,7 @@ func NewCLI() *cobra.Command {
}
}
runCmd
.
Flags
()
.
Bool
(
"verbose"
,
false
,
"Show timings for response"
)
runCmd
.
Flags
()
.
Bool
(
"verbose"
,
false
,
"Show timings for response"
)
runCmd
.
Flags
()
.
Bool
(
"insecure"
,
false
,
"Use an insecure registry"
)
serveCmd
:=
&
cobra
.
Command
{
serveCmd
:=
&
cobra
.
Command
{
Use
:
"serve"
,
Use
:
"serve"
,
...
...
server/images.go
View file @
0a892419
...
@@ -153,7 +153,10 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
...
@@ -153,7 +153,10 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
}
}
func
GetModel
(
name
string
)
(
*
Model
,
error
)
{
func
GetModel
(
name
string
)
(
*
Model
,
error
)
{
mp
:=
ParseModelPath
(
name
)
mp
,
err
:=
ParseModelPath
(
name
,
false
)
if
err
!=
nil
{
return
nil
,
err
}
manifest
,
err
:=
GetManifest
(
mp
)
manifest
,
err
:=
GetManifest
(
mp
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -272,7 +275,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
...
@@ -272,7 +275,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
case
"model"
:
case
"model"
:
fn
(
api
.
ProgressResponse
{
Status
:
"looking for model"
})
fn
(
api
.
ProgressResponse
{
Status
:
"looking for model"
})
embed
.
model
=
c
.
Args
embed
.
model
=
c
.
Args
mp
:=
ParseModelPath
(
c
.
Args
)
mp
,
err
:=
ParseModelPath
(
c
.
Args
,
false
)
if
err
!=
nil
{
return
err
}
mf
,
err
:=
GetManifest
(
mp
)
mf
,
err
:=
GetManifest
(
mp
)
if
err
!=
nil
{
if
err
!=
nil
{
modelFile
,
err
:=
filenameWithPath
(
path
,
c
.
Args
)
modelFile
,
err
:=
filenameWithPath
(
path
,
c
.
Args
)
...
@@ -286,7 +294,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
...
@@ -286,7 +294,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
if
err
:=
PullModel
(
ctx
,
c
.
Args
,
&
RegistryOptions
{},
fn
);
err
!=
nil
{
if
err
:=
PullModel
(
ctx
,
c
.
Args
,
&
RegistryOptions
{},
fn
);
err
!=
nil
{
return
err
return
err
}
}
mf
,
err
=
GetManifest
(
ParseModelPath
(
c
.
Args
)
)
mf
,
err
=
GetManifest
(
mp
)
if
err
!=
nil
{
if
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to open file after pull: %v"
,
err
)
return
fmt
.
Errorf
(
"failed to open file after pull: %v"
,
err
)
}
}
...
@@ -674,7 +682,10 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
...
@@ -674,7 +682,10 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
}
}
func
CreateManifest
(
name
string
,
cfg
*
LayerReader
,
layers
[]
*
Layer
)
error
{
func
CreateManifest
(
name
string
,
cfg
*
LayerReader
,
layers
[]
*
Layer
)
error
{
mp
:=
ParseModelPath
(
name
)
mp
,
err
:=
ParseModelPath
(
name
,
false
)
if
err
!=
nil
{
return
err
}
manifest
:=
ManifestV2
{
manifest
:=
ManifestV2
{
SchemaVersion
:
2
,
SchemaVersion
:
2
,
...
@@ -806,11 +817,22 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
...
@@ -806,11 +817,22 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
}
}
func
CopyModel
(
src
,
dest
string
)
error
{
func
CopyModel
(
src
,
dest
string
)
error
{
srcPath
,
err
:=
ParseModelPath
(
src
)
.
GetManifestPath
(
false
)
srcModelPath
,
err
:=
ParseModelPath
(
src
,
false
)
if
err
!=
nil
{
return
err
}
srcPath
,
err
:=
srcModelPath
.
GetManifestPath
(
false
)
if
err
!=
nil
{
return
err
}
destModelPath
,
err
:=
ParseModelPath
(
dest
,
false
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
}
}
destPath
,
err
:=
ParseModelPath
(
dest
)
.
GetManifestPath
(
true
)
destPath
,
err
:=
destModelPath
.
GetManifestPath
(
true
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
}
}
...
@@ -832,7 +854,10 @@ func CopyModel(src, dest string) error {
...
@@ -832,7 +854,10 @@ func CopyModel(src, dest string) error {
}
}
func
DeleteModel
(
name
string
)
error
{
func
DeleteModel
(
name
string
)
error
{
mp
:=
ParseModelPath
(
name
)
mp
,
err
:=
ParseModelPath
(
name
,
false
)
if
err
!=
nil
{
return
err
}
manifest
,
err
:=
GetManifest
(
mp
)
manifest
,
err
:=
GetManifest
(
mp
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -859,7 +884,10 @@ func DeleteModel(name string) error {
...
@@ -859,7 +884,10 @@ func DeleteModel(name string) error {
return
nil
return
nil
}
}
tag
:=
path
[
:
slashIndex
]
+
":"
+
path
[
slashIndex
+
1
:
]
tag
:=
path
[
:
slashIndex
]
+
":"
+
path
[
slashIndex
+
1
:
]
fmp
:=
ParseModelPath
(
tag
)
fmp
,
err
:=
ParseModelPath
(
tag
,
false
)
if
err
!=
nil
{
return
err
}
// skip the manifest we're trying to delete
// skip the manifest we're trying to delete
if
mp
.
GetFullTagname
()
==
fmp
.
GetFullTagname
()
{
if
mp
.
GetFullTagname
()
==
fmp
.
GetFullTagname
()
{
...
@@ -912,7 +940,10 @@ func DeleteModel(name string) error {
...
@@ -912,7 +940,10 @@ func DeleteModel(name string) error {
}
}
func
PushModel
(
ctx
context
.
Context
,
name
string
,
regOpts
*
RegistryOptions
,
fn
func
(
api
.
ProgressResponse
))
error
{
func
PushModel
(
ctx
context
.
Context
,
name
string
,
regOpts
*
RegistryOptions
,
fn
func
(
api
.
ProgressResponse
))
error
{
mp
:=
ParseModelPath
(
name
)
mp
,
err
:=
ParseModelPath
(
name
,
regOpts
.
Insecure
)
if
err
!=
nil
{
return
err
}
fn
(
api
.
ProgressResponse
{
Status
:
"retrieving manifest"
})
fn
(
api
.
ProgressResponse
{
Status
:
"retrieving manifest"
})
...
@@ -995,7 +1026,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
...
@@ -995,7 +1026,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
}
}
func
PullModel
(
ctx
context
.
Context
,
name
string
,
regOpts
*
RegistryOptions
,
fn
func
(
api
.
ProgressResponse
))
error
{
func
PullModel
(
ctx
context
.
Context
,
name
string
,
regOpts
*
RegistryOptions
,
fn
func
(
api
.
ProgressResponse
))
error
{
mp
:=
ParseModelPath
(
name
)
mp
,
err
:=
ParseModelPath
(
name
,
regOpts
.
Insecure
)
if
err
!=
nil
{
return
err
}
fn
(
api
.
ProgressResponse
{
Status
:
"pulling manifest"
})
fn
(
api
.
ProgressResponse
{
Status
:
"pulling manifest"
})
...
...
server/modelpath.go
View file @
0a892419
package
server
package
server
import
(
import
(
"errors"
"fmt"
"fmt"
"os"
"os"
"path/filepath"
"path/filepath"
...
@@ -23,42 +24,54 @@ const (
...
@@ -23,42 +24,54 @@ const (
DefaultProtocolScheme
=
"https"
DefaultProtocolScheme
=
"https"
)
)
func
ParseModelPath
(
name
string
)
ModelPath
{
var
(
slashParts
:=
strings
.
Split
(
name
,
"/"
)
ErrInvalidImageFormat
=
errors
.
New
(
"invalid image format"
)
var
registry
,
namespace
,
repository
,
tag
string
ErrInvalidProtocol
=
errors
.
New
(
"invalid protocol scheme"
)
ErrInsecureProtocol
=
errors
.
New
(
"insecure protocol http"
)
)
func
ParseModelPath
(
name
string
,
allowInsecure
bool
)
(
ModelPath
,
error
)
{
mp
:=
ModelPath
{
ProtocolScheme
:
DefaultProtocolScheme
,
Registry
:
DefaultRegistry
,
Namespace
:
DefaultNamespace
,
Repository
:
""
,
Tag
:
DefaultTag
,
}
protocol
,
rest
,
didSplit
:=
strings
.
Cut
(
name
,
"://"
)
if
didSplit
{
if
protocol
==
"https"
||
protocol
==
"http"
&&
allowInsecure
{
mp
.
ProtocolScheme
=
protocol
name
=
rest
}
else
if
protocol
==
"http"
&&
!
allowInsecure
{
return
ModelPath
{},
ErrInsecureProtocol
}
else
{
return
ModelPath
{},
ErrInvalidProtocol
}
}
slashParts
:=
strings
.
Split
(
name
,
"/"
)
switch
len
(
slashParts
)
{
switch
len
(
slashParts
)
{
case
3
:
case
3
:
r
egistry
=
slashParts
[
0
]
mp
.
R
egistry
=
slashParts
[
0
]
n
amespace
=
slashParts
[
1
]
mp
.
N
amespace
=
slashParts
[
1
]
r
epository
=
strings
.
Split
(
slashParts
[
2
]
,
":"
)[
0
]
mp
.
R
epository
=
slashParts
[
2
]
case
2
:
case
2
:
registry
=
DefaultRegistry
mp
.
Namespace
=
slashParts
[
0
]
namespace
=
slashParts
[
0
]
mp
.
Repository
=
slashParts
[
1
]
repository
=
strings
.
Split
(
slashParts
[
1
],
":"
)[
0
]
case
1
:
case
1
:
registry
=
DefaultRegistry
mp
.
Repository
=
slashParts
[
0
]
namespace
=
DefaultNamespace
repository
=
strings
.
Split
(
slashParts
[
0
],
":"
)[
0
]
default
:
default
:
fmt
.
Println
(
"Invalid image format."
)
return
ModelPath
{},
ErrInvalidImageFormat
return
ModelPath
{}
}
}
colonParts
:=
strings
.
Split
(
slashParts
[
len
(
slashParts
)
-
1
],
":"
)
if
repo
,
tag
,
didSplit
:=
strings
.
Cut
(
mp
.
Repository
,
":"
);
didSplit
{
if
len
(
colonParts
)
==
2
{
mp
.
Repository
=
repo
tag
=
colonParts
[
1
]
mp
.
Tag
=
tag
}
else
{
tag
=
DefaultTag
}
}
return
ModelPath
{
return
mp
,
nil
ProtocolScheme
:
DefaultProtocolScheme
,
Registry
:
registry
,
Namespace
:
namespace
,
Repository
:
repository
,
Tag
:
tag
,
}
}
}
func
(
mp
ModelPath
)
GetNamespaceRepository
()
string
{
func
(
mp
ModelPath
)
GetNamespaceRepository
()
string
{
...
...
server/modelpath_test.go
0 → 100644
View file @
0a892419
package
server
import
"testing"
func
TestParseModelPath
(
t
*
testing
.
T
)
{
type
input
struct
{
name
string
allowInsecure
bool
}
tests
:=
[]
struct
{
name
string
args
input
want
ModelPath
wantErr
error
}{
{
"full path https"
,
input
{
"https://example.com/ns/repo:tag"
,
false
},
ModelPath
{
ProtocolScheme
:
"https"
,
Registry
:
"example.com"
,
Namespace
:
"ns"
,
Repository
:
"repo"
,
Tag
:
"tag"
,
},
nil
,
},
{
"full path http without insecure"
,
input
{
"http://example.com/ns/repo:tag"
,
false
},
ModelPath
{},
ErrInsecureProtocol
,
},
{
"full path http with insecure"
,
input
{
"http://example.com/ns/repo:tag"
,
true
},
ModelPath
{
ProtocolScheme
:
"http"
,
Registry
:
"example.com"
,
Namespace
:
"ns"
,
Repository
:
"repo"
,
Tag
:
"tag"
,
},
nil
,
},
{
"full path invalid protocol"
,
input
{
"file://example.com/ns/repo:tag"
,
false
},
ModelPath
{},
ErrInvalidProtocol
,
},
{
"no protocol"
,
input
{
"example.com/ns/repo:tag"
,
false
},
ModelPath
{
ProtocolScheme
:
"https"
,
Registry
:
"example.com"
,
Namespace
:
"ns"
,
Repository
:
"repo"
,
Tag
:
"tag"
,
},
nil
,
},
{
"no registry"
,
input
{
"ns/repo:tag"
,
false
},
ModelPath
{
ProtocolScheme
:
"https"
,
Registry
:
DefaultRegistry
,
Namespace
:
"ns"
,
Repository
:
"repo"
,
Tag
:
"tag"
,
},
nil
,
},
{
"no namespace"
,
input
{
"repo:tag"
,
false
},
ModelPath
{
ProtocolScheme
:
"https"
,
Registry
:
DefaultRegistry
,
Namespace
:
DefaultNamespace
,
Repository
:
"repo"
,
Tag
:
"tag"
,
},
nil
,
},
{
"no tag"
,
input
{
"repo"
,
false
},
ModelPath
{
ProtocolScheme
:
"https"
,
Registry
:
DefaultRegistry
,
Namespace
:
DefaultNamespace
,
Repository
:
"repo"
,
Tag
:
DefaultTag
,
},
nil
,
},
{
"invalid image format"
,
input
{
"example.com/a/b/c"
,
false
},
ModelPath
{},
ErrInvalidImageFormat
,
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
got
,
err
:=
ParseModelPath
(
tc
.
args
.
name
,
tc
.
args
.
allowInsecure
)
if
err
!=
tc
.
wantErr
{
t
.
Errorf
(
"got: %q want: %q"
,
err
,
tc
.
wantErr
)
}
if
got
!=
tc
.
want
{
t
.
Errorf
(
"got: %q want: %q"
,
got
,
tc
.
want
)
}
})
}
}
server/routes.go
View file @
0a892419
...
@@ -357,7 +357,12 @@ func ListModelsHandler(c *gin.Context) {
...
@@ -357,7 +357,12 @@ func ListModelsHandler(c *gin.Context) {
return
nil
return
nil
}
}
tag
:=
path
[
:
slashIndex
]
+
":"
+
path
[
slashIndex
+
1
:
]
tag
:=
path
[
:
slashIndex
]
+
":"
+
path
[
slashIndex
+
1
:
]
mp
:=
ParseModelPath
(
tag
)
mp
,
err
:=
ParseModelPath
(
tag
,
false
)
if
err
!=
nil
{
return
err
}
manifest
,
err
:=
GetManifest
(
mp
)
manifest
,
err
:=
GetManifest
(
mp
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"skipping file: %s"
,
fp
)
log
.
Printf
(
"skipping file: %s"
,
fp
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment