globals.py
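"""Process-wide configuration for text_generation_server.

Most values here are read once from environment variables at import time and
are then treated as constants by the rest of the server.
"""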
import torch
import os
from loguru import logger
from typing import Dict, Optional

from text_generation_server.utils.log import log_master

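# Honour log-probability requests only when REQUEST_LOGPROBS is set to 1/true (off by default).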
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
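# The attention backend and the prefix-caching flag must be provided as environment variables (no default).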
ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
    "1",
    "true",
}
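# Split long prefills into smaller chunks; enabled unless PREFILL_CHUNKING is set to 0/false.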
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
    raise RuntimeError(
        "Prefix caching is only supported with flashinfer and flashdecoding attention"
    )

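# Shared memory pool handle so that all captured CUDA graphs can reuse the same allocations (CUDA only).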
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
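# Safety margin used when budgeting GPU memory; must lie strictly between 0 and 1.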
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1

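# Number of tokens stored per KV-cache block.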
# This default is overridden by the CLI.
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
    BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16

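# CUDA_GRAPHS is a comma-separated list of batch sizes to capture CUDA graphs for, e.g. "1,2,4".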
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {cuda_graphs}, expected a comma-separated list of batch sizes to run on: {e}"
        )
else:
    cuda_graphs = None
# Sorting the cuda graph batch sizes in descending order helps reduce memory usage.
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)

CUDA_GRAPHS = cuda_graphs

# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]) -> None:
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


def get_adapter_to_index() -> Optional[Dict[str, int]]:
    return ADAPTER_TO_INDEX