Unverified Commit 4dc4f1be authored by Blake Mizerany's avatar Blake Mizerany Committed by GitHub
Browse files

types/model: restrict digest hash part to a minimum of 2 characters (#3858)

This allows users of a valid Digest to know it has a minimum of 2
characters in the hash part for use when sharding.

This is a reasonable restriction as the hash part is a SHA256 hash which
is 64 characters long, which is the common hash used. There is no
anticipation of using a hash with less than 2 characters.

Also, add MustParseDigest.

Also, replace Digest.Type with Digest.Split for getting both the type
and hash parts together, which is most the common case when asking for
either.
parent 16b52331
...@@ -15,14 +15,10 @@ type Digest struct { ...@@ -15,14 +15,10 @@ type Digest struct {
s string s string
} }
// Type returns the digest type of the digest. // Split returns the digest type and the digest value.
// func (d Digest) Split() (typ, digest string) {
// Example: typ, digest, _ = strings.Cut(d.s, "-")
// return
// ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
typ, _, _ := strings.Cut(d.s, "-")
return typ
} }
// String returns the digest in the form of "<digest-type>-<digest>", or the // String returns the digest in the form of "<digest-type>-<digest>", or the
...@@ -51,12 +47,20 @@ func ParseDigest(s string) Digest { ...@@ -51,12 +47,20 @@ func ParseDigest(s string) Digest {
if !ok { if !ok {
typ, digest, ok = strings.Cut(s, ":") typ, digest, ok = strings.Cut(s, ":")
} }
if ok && isValidDigestType(typ) && isValidHex(digest) { if ok && isValidDigestType(typ) && isValidHex(digest) && len(digest) >= 2 {
return Digest{s: fmt.Sprintf("%s-%s", typ, digest)} return Digest{s: fmt.Sprintf("%s-%s", typ, digest)}
} }
return Digest{} return Digest{}
} }
func MustParseDigest(s string) Digest {
d := ParseDigest(s)
if !d.IsValid() {
panic(fmt.Sprintf("invalid digest: %q", s))
}
return d
}
func isValidDigestType(s string) bool { func isValidDigestType(s string) bool {
if len(s) == 0 { if len(s) == 0 {
return false return false
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"hash/maphash" "hash/maphash"
"io" "io"
"log/slog" "log/slog"
"path"
"path/filepath" "path/filepath"
"slices" "slices"
"strings" "strings"
...@@ -589,10 +590,20 @@ func ParseNameFromURLPath(s, fill string) Name { ...@@ -589,10 +590,20 @@ func ParseNameFromURLPath(s, fill string) Name {
// Example: // Example:
// //
// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag" // ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag"
func (r Name) URLPath() string { func (r Name) DisplayURLPath() string {
return r.DisplayShortest(MaskNothing) return r.DisplayShortest(MaskNothing)
} }
// URLPath returns a complete, canonicalized, relative URL path using the parts of a
// complete Name in the form:
//
// <host>/<namespace>/<model>/<tag>
//
// The parts are downcased.
func (r Name) URLPath() string {
return strings.ToLower(path.Join(r.parts[:PartBuild]...))
}
// ParseNameFromFilepath parses a file path into a Name. The input string must be a // ParseNameFromFilepath parses a file path into a Name. The input string must be a
// valid file path representation of a model name in the form: // valid file path representation of a model name in the form:
// //
......
...@@ -50,10 +50,10 @@ var testNames = map[string]fields{ ...@@ -50,10 +50,10 @@ var testNames = map[string]fields{
"mistral:latest@": {}, "mistral:latest@": {},
// resolved // resolved
"x@sha123-1": {model: "x", digest: "sha123-1"}, "x@sha123-12": {model: "x", digest: "sha123-12"},
"@sha456-2": {digest: "sha456-2"}, "@sha456-22": {digest: "sha456-22"},
"@sha456-1": {},
"@@sha123-1": {}, "@@sha123-22": {},
// preserves case for build // preserves case for build
"x+b": {model: "x", build: "b"}, "x+b": {model: "x", build: "b"},
...@@ -485,7 +485,7 @@ func TestNamePath(t *testing.T) { ...@@ -485,7 +485,7 @@ func TestNamePath(t *testing.T) {
t.Run(tt.in, func(t *testing.T) { t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing) p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p) t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.URLPath(); g != tt.want { if g := p.DisplayURLPath(); g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want) t.Errorf("got = %q; want %q", g, tt.want)
} }
}) })
...@@ -678,18 +678,18 @@ func ExampleName_CompareFold_sort() { ...@@ -678,18 +678,18 @@ func ExampleName_CompareFold_sort() {
func ExampleName_completeAndResolved() { func ExampleName_completeAndResolved() {
for _, s := range []string{ for _, s := range []string{
"x/y/z:latest+q4_0@sha123-1", "x/y/z:latest+q4_0@sha123-abc",
"x/y/z:latest+q4_0", "x/y/z:latest+q4_0",
"@sha123-1", "@sha123-abc",
} { } {
name := ParseName(s, FillNothing) name := ParseName(s, FillNothing)
fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest()) fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
} }
// Output: // Output:
// complete:true resolved:true digest:sha123-1 // complete:true resolved:true digest:sha123-abc
// complete:true resolved:false digest: // complete:true resolved:false digest:
// complete:false resolved:true digest:sha123-1 // complete:false resolved:true digest:sha123-abc
} }
func ExampleName_DisplayShortest() { func ExampleName_DisplayShortest() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment