Commit 5ade3db0 authored by Michael Yang's avatar Michael Yang
Browse files

fix race

block on write which only returns when the channel is closed. this is
contrary to the previous arrangement where the handler may return but
the stream hasn't finished writing. it can lead to the client receiving
unexpected responses (since the request has been handled) or worst case
a nil-pointer dereference as the stream tries to flush a nil writer
parent 965f9ad0
...@@ -58,9 +58,6 @@ func generate(c *gin.Context) { ...@@ -58,9 +58,6 @@ func generate(c *gin.Context) {
req.Model = path.Join(cacheDir(), "models", req.Model+".bin") req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
} }
ch := make(chan any)
go stream(c, ch)
templateNames := make([]string, 0, len(templates.Templates())) templateNames := make([]string, 0, len(templates.Templates()))
for _, template := range templates.Templates() { for _, template := range templates.Templates() {
templateNames = append(templateNames, template.Name()) templateNames = append(templateNames, template.Name())
...@@ -84,21 +81,21 @@ func generate(c *gin.Context) { ...@@ -84,21 +81,21 @@ func generate(c *gin.Context) {
} }
defer llm.Close() defer llm.Close()
fn := func(r api.GenerateResponse) { ch := make(chan any)
r.Model = req.Model go func() {
r.CreatedAt = time.Now().UTC() defer close(ch)
if r.Done { llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
r.TotalDuration = time.Since(start) r.Model = req.Model
} r.CreatedAt = time.Now().UTC()
if r.Done {
ch <- r r.TotalDuration = time.Since(start)
} }
if err := llm.Predict(req.Context, req.Prompt, fn); err != nil { ch <- r
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) })
return }()
}
streamResponse(c, ch)
} }
func pull(c *gin.Context) { func pull(c *gin.Context) {
...@@ -133,20 +130,18 @@ func pull(c *gin.Context) { ...@@ -133,20 +130,18 @@ func pull(c *gin.Context) {
} }
ch := make(chan any) ch := make(chan any)
go stream(c, ch) go func() {
defer close(ch)
fn := func(total, completed int64) { saveModel(remote, func(total, completed int64) {
ch <- api.PullProgress{ ch <- api.PullProgress{
Total: total, Total: total,
Completed: completed, Completed: completed,
Percent: float64(completed) / float64(total) * 100, Percent: float64(completed) / float64(total) * 100,
} }
} })
}()
if err := saveModel(remote, fn); err != nil { streamResponse(c, ch)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
func Serve(ln net.Listener) error { func Serve(ln net.Listener) error {
...@@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i ...@@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
return return
} }
func stream(c *gin.Context, ch chan any) { func streamResponse(c *gin.Context, ch chan any) {
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
val, ok := <-ch val, ok := <-ch
if !ok { if !ok {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment