"tests/vscode:/vscode.git/clone" did not exist on "15d4cf15f608be5d923b7a1b1ddaa5541f4ce069"
cli.py 3.45 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
import os
2
import sys
Olivier Dehaene's avatar
Olivier Dehaene committed
3
4
5
import typer

from pathlib import Path
6
from loguru import logger
7
from typing import Optional
Olivier Dehaene's avatar
Olivier Dehaene committed
8
9
10
11
12
13


# Single Typer application object; the functions below register themselves
# as CLI sub-commands via the @app.command() decorator.
app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    # Default must be a Path, not a str, so direct Python callers get the
    # annotated type (typer coerces CLI input either way).
    uds_path: Path = Path("/tmp/text-generation-server"),
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    """Start the text-generation gRPC server for a model.

    Args:
        model_id: Hugging Face model identifier to serve.
        revision: Optional model revision (branch, tag or commit hash).
        sharded: If True, run as one shard of a multi-process deployment;
            requires the torch.distributed environment variables below.
        quantize: Whether to quantize the model weights.
        uds_path: Base path for the Unix domain socket(s) the server binds.
        logger_level: Loguru log level (e.g. "INFO", "DEBUG").
        json_output: If True, emit logs as serialized JSON records.
        otlp_endpoint: If set, enable OpenTelemetry tracing to this endpoint.
    """
    if sharded:
        # The distributed launcher must provide these for shard coordination;
        # fail fast with a clear message instead of a cryptic error later.
        for env_var in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            assert (
                os.getenv(env_var, None) is not None
            ), f"{env_var} must be set when sharded is True"

    # Remove default handler so we fully control formatting/level below.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing, tagged with this shard's rank.
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)


def _strip_pytorch_prefix(stem: str) -> str:
    """Drop a literal leading "pytorch_" from *stem*, if present.

    Note: this intentionally replaces ``stem.lstrip("pytorch_")``, which is a
    bug — ``str.lstrip`` strips any of the *characters* {p,y,t,o,r,c,h,_}
    from the left, not the literal prefix, and can eat extra leading
    characters of the real filename.
    """
    prefix = "pytorch_"
    return stem[len(prefix):] if stem.startswith(prefix) else stem


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Download model weights into the local cache, converting if needed.

    Skips the download when matching files are already cached. If no
    safetensors weights exist on the Hub, falls back to downloading the
    PyTorch ``.bin`` weights and converting them to safetensors.

    Args:
        model_id: Hugging Face model identifier to download.
        revision: Optional model revision (branch, tag or commit hash).
        extension: Weight file extension to look for on the Hub.
        logger_level: Loguru log level (e.g. "INFO", "DEBUG").
        json_output: If True, emit logs as serialized JSON records.
    """
    # Remove default handler so we fully control formatting/level below.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already download
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info(
            "Files are already present in the local cache. " "Skipping download."
        )
        return
    # Local files not found
    except utils.LocalEntryNotFoundError:
        pass

    # Download weights directly
    try:
        filenames = utils.weight_hub_files(model_id, revision, extension)
        utils.download_weights(filenames, model_id, revision)
    except utils.EntryNotFoundError as e:
        # Only the default safetensors extension has a conversion fallback.
        if extension != ".safetensors":
            raise e

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights instead."
        )

        # Try to see if there are pytorch weights
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
        # Map each pytorch file to its safetensors destination path.
        local_st_files = [
            p.parent / f"{_strip_pytorch_prefix(p.stem)}.safetensors"
            for p in local_pt_files
        ]
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files)


# Allow running the CLI directly (`python cli.py ...`) as well as via the
# installed console entry point.
if __name__ == "__main__":
    app()