Unverified Commit c0960e29 authored by Bruce MacDonald, committed by GitHub

retry on concurrent request failure (#1483)

- remove parallel
parent 5314fc9b
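The core of this change is a bounded retry loop: the whole request-and-stream cycle in `Predict` is wrapped in `for retries := 0; retries < maxRetries; retries++`, sleeping `retryDelay` before each retry and giving up only after `maxRetries` attempts. A minimal, self-contained sketch of that pattern follows; `doRequest` and `errSlotUnavailable` are illustrative stand-ins, not names from the patch.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

const maxRetries = 3
const retryDelay = 1 * time.Second

// errSlotUnavailable stands in for the retryable "slot unavailable"
// condition the patch detects in the response stream.
var errSlotUnavailable = errors.New("slot unavailable")

// doRequest is a hypothetical stand-in for the POST-and-stream cycle:
// here the first two attempts find no free slot, the third succeeds.
func doRequest(attempt int) error {
	if attempt < 2 {
		return errSlotUnavailable
	}
	return nil
}

func main() {
	for retries := 0; retries < maxRetries; retries++ {
		if retries > 0 {
			time.Sleep(retryDelay) // wait before retrying
		}
		err := doRequest(retries)
		if err == nil {
			fmt.Println("success on attempt", retries+1)
			return
		}
		if !errors.Is(err, errSlotUnavailable) {
			fmt.Println("fatal:", err) // non-retryable errors stop immediately
			return
		}
	}
	fmt.Println("max retries exceeded")
}
```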
@@ -412,10 +412,6 @@ func newLlama(model string, adapters, projectors []string, runners []ModelRunner
 	port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 	params := append(params, "--port", strconv.Itoa(port))
 
-	if runner.Type == "gguf" {
-		params = append(params, "--parallel", "2")
-	}
-
 	ctx, cancel := context.WithCancel(context.Background())
 	cmd := exec.CommandContext(
 		ctx,
@@ -549,6 +545,8 @@ type prediction struct {
 }
 
 const maxBufferSize = 512 * format.KiloByte
+const maxRetries = 3
+const retryDelay = 1 * time.Second
 
 type PredictOpts struct {
 	Prompt string
@@ -570,6 +568,11 @@ type PredictResult struct {
 	EvalDuration       time.Duration
 }
 
+// IsRetryable checks if the line matches a condition that can be retried
+func isRetryable(line []byte) bool {
+	return bytes.Contains(line, []byte("slot unavailable"))
+}
+
 func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	imageData := llm.ImageData
 	if len(predict.Images) > 0 {
@@ -607,98 +610,116 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 		request["grammar"] = jsonGrammar
 	}
 
-	// Handling JSON marshaling with special characters unescaped.
-	buffer := &bytes.Buffer{}
-	enc := json.NewEncoder(buffer)
-	enc.SetEscapeHTML(false)
-
-	if err := enc.Encode(request); err != nil {
-		return fmt.Errorf("failed to marshal data: %v", err)
-	}
-
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-	if err != nil {
-		return fmt.Errorf("error creating POST request: %v", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return fmt.Errorf("POST predict: %v", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= 400 {
-		bodyBytes, err := io.ReadAll(resp.Body)
-		if err != nil {
-			return fmt.Errorf("failed reading llm error response: %w", err)
-		}
-		log.Printf("llm predict error: %s", bodyBytes)
-		return fmt.Errorf("%s", bodyBytes)
-	}
-
-	scanner := bufio.NewScanner(resp.Body)
-	// increase the buffer size to avoid running out of space
-	buf := make([]byte, 0, maxBufferSize)
-	scanner.Buffer(buf, maxBufferSize)
-	for scanner.Scan() {
-		select {
-		case <-ctx.Done():
-			// This handles the request cancellation
-			return ctx.Err()
-		default:
-			line := scanner.Bytes()
-			if len(line) == 0 {
-				continue
-			}
-
-			evt, ok := bytes.CutPrefix(line, []byte("data: "))
-			if !ok {
-				return fmt.Errorf("error parsing llm response stream: %s", line)
-			}
-
-			var p prediction
-			if err := json.Unmarshal(evt, &p); err != nil {
-				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-			}
-
-			if p.Content != "" {
-				fn(PredictResult{
-					CreatedAt: time.Now().UTC(),
-					Content:   p.Content,
-				})
-			}
-
-			if p.Stop {
-				fn(PredictResult{
-					CreatedAt:          time.Now().UTC(),
-					TotalDuration:      time.Since(predict.CheckpointStart),
-					Done:               true,
-					PromptEvalCount:    p.Timings.PromptN,
-					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
-					EvalCount:          p.Timings.PredictedN,
-					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
-				})
-				return nil
-			}
-		}
-	}
-
-	if err := scanner.Err(); err != nil {
-		if strings.Contains(err.Error(), "unexpected EOF") {
-			// this means the llama runner subprocess crashed
-			llm.Close()
-			if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
-				return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
-			}
-			return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
-		}
-		return fmt.Errorf("error reading llm response: %v", err)
-	}
-
-	return nil
+	for retries := 0; retries < maxRetries; retries++ {
+		if retries > 0 {
+			time.Sleep(retryDelay) // wait before retrying
+		}
+
+		// Handling JSON marshaling with special characters unescaped.
+		buffer := &bytes.Buffer{}
+		enc := json.NewEncoder(buffer)
+		enc.SetEscapeHTML(false)
+
+		if err := enc.Encode(request); err != nil {
+			return fmt.Errorf("failed to marshal data: %v", err)
+		}
+
+		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+		if err != nil {
+			return fmt.Errorf("error creating POST request: %v", err)
+		}
+		req.Header.Set("Content-Type", "application/json")
+
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			return fmt.Errorf("POST predict: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode >= 400 {
+			bodyBytes, err := io.ReadAll(resp.Body)
+			if err != nil {
+				return fmt.Errorf("failed reading llm error response: %w", err)
+			}
+			log.Printf("llm predict error: %s", bodyBytes)
+			return fmt.Errorf("%s", bodyBytes)
+		}
+
+		scanner := bufio.NewScanner(resp.Body)
+		// increase the buffer size to avoid running out of space
+		buf := make([]byte, 0, maxBufferSize)
+		scanner.Buffer(buf, maxBufferSize)
+
+		retryNeeded := false
+		for scanner.Scan() {
+			select {
+			case <-ctx.Done():
+				// This handles the request cancellation
+				return ctx.Err()
+			default:
+				line := scanner.Bytes()
+				if len(line) == 0 {
+					continue
+				}
+
+				if isRetryable(line) {
+					retryNeeded = true
+					break
+				}
+
+				evt, ok := bytes.CutPrefix(line, []byte("data: "))
+				if !ok {
+					return fmt.Errorf("error parsing llm response stream: %s", line)
+				}
+
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+				}
+
+				if p.Content != "" {
+					fn(PredictResult{
+						CreatedAt: time.Now().UTC(),
+						Content:   p.Content,
+					})
+				}
+
+				if p.Stop {
+					fn(PredictResult{
+						CreatedAt:          time.Now().UTC(),
+						TotalDuration:      time.Since(predict.CheckpointStart),
+						Done:               true,
+						PromptEvalCount:    p.Timings.PromptN,
+						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+						EvalCount:          p.Timings.PredictedN,
+						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
+					})
+					return nil
+				}
+			}
+		}
+
+		if err := scanner.Err(); err != nil {
+			if strings.Contains(err.Error(), "unexpected EOF") {
+				// this means the llama runner subprocess crashed
+				llm.Close()
+				if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+					return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+				}
+				return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
+			}
+			return fmt.Errorf("error reading llm response: %v", err)
+		}
+
+		if !retryNeeded {
+			return nil // success
+		}
+	}
+
+	// should never reach here ideally
+	return fmt.Errorf("max retries exceeded")
 }
 
 type TokenizeRequest struct {
...
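Inside the streaming loop, each line is classified before it is parsed: a line matching `isRetryable` breaks out to trigger another attempt, while normal lines must carry the `data: ` server-sent-event prefix, which is stripped with `bytes.CutPrefix` before JSON decoding. A small sketch of that classification step, using a made-up sample stream:

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"strings"
)

// isRetryable mirrors the patch: a line mentioning "slot unavailable"
// means the runner had no free slot and the request can be retried.
func isRetryable(line []byte) bool {
	return bytes.Contains(line, []byte("slot unavailable"))
}

func main() {
	// A made-up response stream: one SSE data event, then a retryable error.
	stream := "data: {\"content\":\"hello\"}\nslot unavailable\n"

	scanner := bufio.NewScanner(strings.NewReader(stream))
	for scanner.Scan() {
		line := scanner.Bytes()
		if len(line) == 0 {
			continue
		}
		if isRetryable(line) {
			fmt.Println("retry needed")
			break
		}
		// Strip the SSE prefix; anything else would be a protocol error.
		if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
			fmt.Printf("event payload: %s\n", evt)
		}
	}
}
```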