server.py
import json
import os
from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}


@app.route("/load", methods=["POST"])
def load():
    data = request.get_json()
    model = data.get("model")

    if not model:
        return Response("Model is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400

    if model not in llms:
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    return Response(status=204)


@app.route("/unload", methods=["POST"])
def unload():
    data = request.get_json()
    model = data.get("model")

    if not model:
        return Response("Model is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400

    llms.pop(model, None)

    return Response(status=204)
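
# Example requests, as a sketch only: the model name "llama-7b" is hypothetical;
# any file at ../models/<name>.bin works. The server listens on port 5000 (see below).
#   curl -X POST http://localhost:5000/load -H "Content-Type: application/json" -d '{"model": "llama-7b"}'
#   curl -X POST http://localhost:5000/unload -H "Content-Type: application/json" -d '{"model": "llama-7b"}'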


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")

    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400

    if model not in llms:
        # auto load
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    def stream_response():
        # run the model and stream completion chunks as they are generated
        stream = llms[model](
            str(prompt),  # TODO: optimize prompt based on model
            max_tokens=4096,
            stop=["Q:", "\n"],
            echo=True,
            stream=True,
        )
        # serialize each chunk as JSON and send it to the client
        for output in stream:
            yield json.dumps(output)

    return Response(
        stream_with_context(stream_response()), mimetype="text/event-stream"
    )
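
# Example request, as a sketch only (hypothetical model name and prompt); the
# response streams one JSON object per generated chunk:
#   curl -X POST http://localhost:5000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "llama-7b", "prompt": "Q: What is the capital of France? A:"}'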


if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5000)