"graphbolt/src/vscode:/vscode.git/clone" did not exist on "aad12df6a63b7c2269bc8ed68b10b9099b5df46d"
Unverified commit 1deb35ca, authored by Bruce MacDonald and committed by GitHub
Browse files

use loaded llm for generating model file embeddings

parents e2de8868 326de489
......@@ -263,7 +263,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
var layers []*LayerReader
params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
embed := EmbeddingParams{fn: fn}
for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name {
......@@ -291,6 +291,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err
}
} else {
embed.model = modelFile
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
......@@ -422,8 +423,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layers = append(layers, l)
// apply these parameters to the embedding options, in case embeddings need to be generated using this model
embed.opts = api.DefaultOptions()
embed.opts.FromMap(formattedParams)
embed.opts = formattedParams
}
// generate the embedding layers
......@@ -469,7 +469,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
type EmbeddingParams struct {
model string
opts api.Options
opts map[string]interface{}
files []string // paths to files to embed
fn func(resp api.ProgressResponse)
}
......@@ -478,32 +478,22 @@ type EmbeddingParams struct {
func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{}
if len(e.files) > 0 {
if _, err := os.Stat(e.model); err != nil {
if os.IsNotExist(err) {
// this is a model name rather than the file
model, err := GetModel(e.model)
if err != nil {
return nil, fmt.Errorf("failed to get model to generate embeddings: %v", err)
}
e.model = model.ModelPath
} else {
return nil, fmt.Errorf("failed to get model file to generate embeddings: %v", err)
// check if the model is a file path or a model name
model, err := GetModel(e.model)
if err != nil {
if !strings.Contains(err.Error(), "couldn't open file") {
return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
}
// the model may be a file path, create a model from this file
model = &Model{ModelPath: e.model}
}
e.opts.EmbeddingOnly = true
llmModel, err := llm.New(e.model, []string{}, e.opts)
if err != nil {
if err := load(model, e.opts, defaultSessionDuration); err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
defer func() {
if llmModel != nil {
llmModel.Close()
}
}()
// this will be used to check if we already have embeddings for a file
modelInfo, err := os.Stat(e.model)
modelInfo, err := os.Stat(model.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to get model file info: %v", err)
}
......@@ -561,7 +551,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
continue
}
embed, err := llmModel.Embedding(d)
embed, err := loaded.llm.Embedding(d)
if err != nil {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue
......
......@@ -38,6 +38,8 @@ var loaded struct {
options api.Options
}
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
opts := api.DefaultOptions()
......@@ -134,7 +136,7 @@ func GenerateHandler(c *gin.Context) {
return
}
sessionDuration := 5 * time.Minute
sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified
if err := load(model, req.Options, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment