"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "454f8be7b9e5a8a9568f3bb5b569ebff4b4e4d92"
Unverified commit 1deb35ca, authored by Bruce MacDonald, committed by GitHub
Browse files

use loaded llm for generating model file embeddings

parents e2de8868 326de489
...@@ -263,7 +263,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api ...@@ -263,7 +263,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()} embed := EmbeddingParams{fn: fn}
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name { switch c.Name {
...@@ -291,6 +291,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api ...@@ -291,6 +291,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
} else { } else {
embed.model = modelFile
// create a model from this specified file // create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"}) fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile) file, err := os.Open(modelFile)
...@@ -422,8 +423,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api ...@@ -422,8 +423,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layers = append(layers, l) layers = append(layers, l)
// apply these parameters to the embedding options, in case embeddings need to be generated using this model // apply these parameters to the embedding options, in case embeddings need to be generated using this model
embed.opts = api.DefaultOptions() embed.opts = formattedParams
embed.opts.FromMap(formattedParams)
} }
// generate the embedding layers // generate the embedding layers
...@@ -469,7 +469,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api ...@@ -469,7 +469,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
type EmbeddingParams struct { type EmbeddingParams struct {
model string model string
opts api.Options opts map[string]interface{}
files []string // paths to files to embed files []string // paths to files to embed
fn func(resp api.ProgressResponse) fn func(resp api.ProgressResponse)
} }
...@@ -478,32 +478,22 @@ type EmbeddingParams struct { ...@@ -478,32 +478,22 @@ type EmbeddingParams struct {
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{} layers := []*LayerReader{}
if len(e.files) > 0 { if len(e.files) > 0 {
if _, err := os.Stat(e.model); err != nil { // check if the model is a file path or a model name
if os.IsNotExist(err) { model, err := GetModel(e.model)
// this is a model name rather than the file if err != nil {
model, err := GetModel(e.model) if !strings.Contains(err.Error(), "couldn't open file") {
if err != nil { return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
}
e.model = model.ModelPath
} else {
return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
} }
// the model may be a file path, create a model from this file
model = &Model{ModelPath: e.model}
} }
e.opts.EmbeddingOnly = true if err := load(model, e.opts, defaultSessionDuration); err != nil {
llmModel, err := llm.New(e.model, []string{}, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err) return nil, fmt.Errorf("load model to generate embeddings: %v", err)
} }
defer func() {
if llmModel != nil {
llmModel.Close()
}
}()
// this will be used to check if we already have embeddings for a file // this will be used to check if we already have embeddings for a file
modelInfo, err := os.Stat(e.model) modelInfo, err := os.Stat(model.ModelPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get model file info: %v", err) return nil, fmt.Errorf("failed to get model file info: %v", err)
} }
...@@ -561,7 +551,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { ...@@ -561,7 +551,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]}) embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
continue continue
} }
embed, err := llmModel.Embedding(d) embed, err := loaded.llm.Embedding(d)
if err != nil { if err != nil {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err) log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue continue
......
...@@ -38,6 +38,8 @@ var loaded struct { ...@@ -38,6 +38,8 @@ var loaded struct {
options api.Options options api.Options
} }
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
opts := api.DefaultOptions() opts := api.DefaultOptions()
...@@ -134,7 +136,7 @@ func GenerateHandler(c *gin.Context) { ...@@ -134,7 +136,7 @@ func GenerateHandler(c *gin.Context) {
return return
} }
sessionDuration := 5 * time.Minute sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified
if err := load(model, req.Options, sessionDuration); err != nil { if err := load(model, req.Options, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment