Unverified Commit 7a1ba585 authored by OlivierDehaene, committed by GitHub

fix(docker): fix docker image dependencies (#187)

parent 379c5c4d
@@ -79,8 +79,8 @@ jobs:
           flavor: |
             latest=auto
           images: |
-            ghcr.io/huggingface/text-generation-inference
             registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+            ghcr.io/huggingface/text-generation-inference
             db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
           tags: |
             type=semver,pattern={{version}}
......
+# Rust builder
 FROM lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
 WORKDIR /usr/src
@@ -27,51 +28,127 @@ COPY router router
 COPY launcher launcher
 RUN cargo build --release

-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as base
+# Python builder
+# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
+FROM debian:bullseye-slim as pytorch-install
+
+ARG PYTORCH_VERSION=2.0.0
+ARG PYTHON_VERSION=3.9
+ARG CUDA_VERSION=11.8
 ARG MAMBA_VERSION=23.1.0-1
+ARG CUDA_CHANNEL=nvidia
+ARG INSTALL_CHANNEL=pytorch
+# Automatically set by buildx
+ARG TARGETPLATFORM
+
+ENV PATH /opt/conda/bin:$PATH
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+        build-essential \
+        ca-certificates \
+        ccache \
+        curl \
+        git && \
+        rm -rf /var/lib/apt/lists/*
+
+# Install conda
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+        *)             MAMBA_ARCH=x86_64  ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    rm ~/mambaforge.sh
+
+# Install pytorch
+# On arm64 we exit with an error code
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") exit 1 ;; \
+        *) /opt/conda/bin/conda update -y conda && \
+           /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
+    esac && \
+    /opt/conda/bin/conda clean -ya
+
+# CUDA kernels builder image
+FROM pytorch-install as kernel-builder
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+        ninja-build \
+        && rm -rf /var/lib/apt/lists/*
+
+RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
+    /opt/conda/bin/conda clean -ya
+
+# Build Flash Attention CUDA kernels
+FROM kernel-builder as flash-att-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-flash-att Makefile
+
+# Build specific version of flash attention
+RUN make build-flash-attention
+
+# Build Transformers CUDA kernels
+FROM kernel-builder as transformers-builder

-ENV LANG=C.UTF-8 \
-    LC_ALL=C.UTF-8 \
-    DEBIAN_FRONTEND=noninteractive \
-    HUGGINGFACE_HUB_CACHE=/data \
+WORKDIR /usr/src
+
+COPY server/Makefile-transformers Makefile
+
+# Build specific version of transformers
+RUN BUILD_EXTENSIONS="True" make build-transformers
+
+# Text Generation Inference base image
+FROM debian:bullseye-slim as base
+
+# Conda env
+ENV PATH=/opt/conda/bin:$PATH \
+    CONDA_PREFIX=/opt/conda
+
+# Text Generation Inference base env
+ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     MODEL_ID=bigscience/bloom-560m \
     QUANTIZE=false \
     NUM_SHARD=1 \
-    PORT=80 \
-    CUDA_HOME=/usr/local/cuda \
-    LD_LIBRARY_PATH="/opt/conda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
-    PATH=$PATH:/opt/conda/bin:/usr/local/cuda/bin
+    PORT=80

-RUN apt-get update && apt-get install -y git curl libssl-dev ninja-build && rm -rf /var/lib/apt/lists/*
-
-RUN cd ~ && \
-    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh" \
-    chmod +x ~/mambaforge.sh && \
-    bash ~/mambaforge.sh -b -p /opt/conda && \
-    rm ~/mambaforge.sh
+LABEL com.nvidia.volumes.needed="nvidia_driver"

 WORKDIR /usr/src

-# Install torch
-RUN pip install torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+        libssl-dev \
+        ca-certificates \
+        make \
+        && rm -rf /var/lib/apt/lists/*

-# Install specific version of flash attention
-COPY server/Makefile-flash-att server/Makefile
-RUN cd server && make install-flash-attention
+# Copy conda with PyTorch installed
+COPY --from=pytorch-install /opt/conda /opt/conda

-# Install specific version of transformers
-COPY server/Makefile-transformers server/Makefile
-RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
+# Copy build artifacts from flash attention builder
+COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

-COPY server/Makefile server/Makefile
+# Copy build artifacts from transformers builder
+COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers
+COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
+
+# Install transformers dependencies
+RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir

 # Install server
 COPY proto proto
 COPY server server
+COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
+    pip install -r requirements.txt && \
     pip install ".[bnb]" --no-cache-dir

 # Install router
......
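A note on the cpython-39 paths above: the kernel artifacts are built with the conda Python 3.9 from the pytorch-install stage, so the runtime image has to ship the same interpreter for the copied extensions to load. Below is a minimal smoke test one could run inside the final image; the extension module names are assumptions based on the flash-attention source tree, not taken from this diff.

```python
# Hypothetical smoke test for the copied kernel artifacts. The module names
# (flash_attn_cuda, dropout_layer_norm, rotary_emb) are assumed from
# flash-attention's csrc layout; adjust if the build names differ.
import sys

# The builder stages produce lib.linux-x86_64-cpython-39 artifacts.
assert sys.version_info[:2] == (3, 9), "interpreter does not match builder ABI"

import flash_attn_cuda      # main flash attention kernels
import dropout_layer_norm   # built from csrc/layer_norm
import rotary_emb           # built from csrc/rotary

print("CUDA kernel extensions import cleanly")
```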
@@ -8,6 +8,7 @@ description = "Text Generation Launcher"
 [dependencies]
 clap = { version = "4.1.4", features = ["derive", "env"] }
 ctrlc = { version = "3.2.5", features = ["termination"] }
+serde = { version = "1.0.152", features = ["derive"] }
 serde_json = "1.0.93"
 subprocess = "0.2.9"
 tracing = "0.1.37"
@@ -16,4 +17,3 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
 [dev-dependencies]
 float_eq = "1.0.1"
 reqwest = { version = "0.11.14", features = ["blocking", "json"] }
-serde = { version = "1.0.152", features = ["derive"] }
 use clap::Parser;
-use serde_json::Value;
+use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
 use std::io::{BufRead, BufReader, Read};
@@ -244,11 +244,8 @@ fn main() -> ExitCode {
         let _span = tracing::span!(tracing::Level::INFO, "download").entered();
         for line in stdout.lines() {
             // Parse loguru logs
-            if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
-                if let Some(text) = value.get("text") {
-                    // Format escaped newlines
-                    tracing::info!("{}", text.to_string().replace("\\n", ""));
-                }
+            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
+                log.trace();
             }
         }
     });
@@ -525,7 +522,7 @@ fn shard_manager(
         "--uds-path".to_string(),
         uds_path,
         "--logger-level".to_string(),
-        "ERROR".to_string(),
+        "INFO".to_string(),
         "--json-output".to_string(),
     ];
@@ -643,11 +640,8 @@ fn shard_manager(
         let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
         for line in stdout.lines() {
             // Parse loguru logs
-            if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
-                if let Some(text) = value.get("text") {
-                    // Format escaped newlines
-                    tracing::error!("{}", text.to_string().replace("\\n", "\n"));
-                }
+            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
+                log.trace();
             }
         }
     });
@@ -708,3 +702,45 @@ fn num_cuda_devices() -> Option<usize> {
     }
     None
 }
+
+#[derive(Deserialize)]
+#[serde(rename_all = "UPPERCASE")]
+enum PythonLogLevelEnum {
+    Trace,
+    Debug,
+    Info,
+    Success,
+    Warning,
+    Error,
+    Critical,
+}
+
+#[derive(Deserialize)]
+struct PythonLogLevel {
+    name: PythonLogLevelEnum,
+}
+
+#[derive(Deserialize)]
+struct PythonLogRecord {
+    level: PythonLogLevel,
+}
+
+#[derive(Deserialize)]
+struct PythonLogMessage {
+    text: String,
+    record: PythonLogRecord,
+}
+
+impl PythonLogMessage {
+    fn trace(&self) {
+        match self.record.level.name {
+            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
+            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
+            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
+            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
+            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
+            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
+            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
+        }
+    }
+}
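For context on the structs above: they mirror loguru's serialized record layout, and serde ignores unknown JSON keys by default, which is why only `text` and `record.level.name` need to be modeled. A minimal sketch of the producer side, assuming loguru's `serialize=True` output format:

```python
# Minimal sketch of what the Python shard emits with --json-output,
# assuming loguru's serialize=True format; the launcher reads only
# "text" and "record" -> "level" -> "name".
import sys

from loguru import logger

logger.remove()
logger.add(sys.stdout, serialize=True)  # one JSON object per line

logger.info("Shard ready")
# Emits a line shaped roughly like:
# {"text": "... | INFO | ... - Shard ready\n",
#  "record": {..., "level": {"icon": "...", "name": "INFO", "no": 20}, ...}}
```

A stdout line that is not valid JSON simply fails the `serde_json::from_str` parse, so the `if let Ok(log)` in the launcher skips it silently.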
@@ -16,6 +16,7 @@ install-torch:
 install: gen-server install-torch install-transformers
 	pip install pip --upgrade
+	pip install -r requirements.txt
 	pip install -e . --no-cache-dir

 run-dev:
......
@@ -33,7 +33,7 @@ python-versions = ">=3.7,<4.0"

 [[package]]
 name = "bitsandbytes"
-version = "0.35.4"
+version = "0.38.1"
 description = "8-bit optimizers and matrix multiplication routines."
 category = "main"
 optional = false
@@ -138,17 +138,17 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]

 [[package]]
 name = "grpc-interceptor"
-version = "0.15.0"
+version = "0.15.1"
 description = "Simplifies gRPC interceptors"
 category = "main"
 optional = false
-python-versions = ">=3.6.1,<4.0.0"
+python-versions = ">=3.7,<4.0"

 [package.dependencies]
-grpcio = ">=1.32.0,<2.0.0"
+grpcio = ">=1.49.1,<2.0.0"

 [package.extras]
-testing = ["protobuf (>=3.6.0)"]
+testing = ["protobuf (>=4.21.9)"]

 [[package]]
 name = "grpcio"
@@ -597,7 +597,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]

 [[package]]
 name = "pytest"
-version = "7.3.0"
+version = "7.3.1"
 description = "pytest: simple powerful testing with Python"
 category = "dev"
 optional = false
@@ -833,7 +833,7 @@ bnb = ["bitsandbytes"]

 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "6141d488429e0ab579028036e8e4cbc54f583b48214cb4a6be066bb7ce5154db"
+content-hash = "e05491a03938b79a71b498f2759169f5a41181084158fde5993e7dcb25292cb0"

 [metadata.files]
 accelerate = [
@@ -845,8 +845,8 @@ backoff = [
     {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
 ]
 bitsandbytes = [
-    {file = "bitsandbytes-0.35.4-py3-none-any.whl", hash = "sha256:201f168538ccfbd7594568a2f86c149cec8352782301076a15a783695ecec7fb"},
-    {file = "bitsandbytes-0.35.4.tar.gz", hash = "sha256:b23db6b91cd73cb14faf9841a66bffa5c1722f9b8b57039ef2fb461ac22dd2a6"},
+    {file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"},
+    {file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"},
 ]
 certifi = [
     {file = "certifi-2022.12.7-py3-none-any.whl", hash = "sha256:4ad3232f5e926d6718ec31cfc1fcadfde020920e278684144551c91769c7bc18"},
@@ -973,8 +973,8 @@ googleapis-common-protos = [
     {file = "googleapis_common_protos-1.59.0-py2.py3-none-any.whl", hash = "sha256:b287dc48449d1d41af0c69f4ea26242b5ae4c3d7249a38b0984c86a4caffff1f"},
 ]
 grpc-interceptor = [
-    {file = "grpc-interceptor-0.15.0.tar.gz", hash = "sha256:5c1aa9680b1d7e12259960c38057b121826860b05ebbc1001c74343b7ad1455e"},
-    {file = "grpc_interceptor-0.15.0-py3-none-any.whl", hash = "sha256:63e390162e64df96c39c40508eb697def76a7cafac32a7eaf9272093eec1109e"},
+    {file = "grpc-interceptor-0.15.1.tar.gz", hash = "sha256:3efadbc9aead272ac7a360c75c4bd96233094c9a5192dbb51c6156246bd64ba0"},
+    {file = "grpc_interceptor-0.15.1-py3-none-any.whl", hash = "sha256:1cc52c34b0d7ff34512fb7780742ecda37bf3caa18ecc5f33f09b4f74e96b276"},
 ]
 grpcio = [
     {file = "grpcio-1.53.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:752d2949b40e12e6ad3ed8cc552a65b54d226504f6b1fb67cab2ccee502cc06f"},
@@ -1329,8 +1329,8 @@ psutil = [
     {file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"},
 ]
 pytest = [
-    {file = "pytest-7.3.0-py3-none-any.whl", hash = "sha256:933051fa1bfbd38a21e73c3960cebdad4cf59483ddba7696c48509727e17f201"},
-    {file = "pytest-7.3.0.tar.gz", hash = "sha256:58ecc27ebf0ea643ebfdf7fb1249335da761a00c9f955bcd922349bcb68ee57d"},
+    {file = "pytest-7.3.1-py3-none-any.whl", hash = "sha256:3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362"},
+    {file = "pytest-7.3.1.tar.gz", hash = "sha256:434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3"},
 ]
 PyYAML = [
     {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
......
@@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
 accelerate = "^0.15.0"
-bitsandbytes = "^0.35.1"
+bitsandbytes = "^0.38.1"
 safetensors = "^0.2.4"
 loguru = "^0.6.0"
 opentelemetry-api = "^1.15.0"
......
This diff is collapsed.
@@ -6,8 +6,6 @@ from pathlib import Path
 from loguru import logger
 from typing import Optional

-from text_generation_server import server, utils
-from text_generation_server.tracing import setup_tracing

 app = typer.Typer()
@@ -48,6 +46,11 @@ def serve(
         backtrace=True,
         diagnose=False,
     )

+    # Import here after the logger is added to log potential import exceptions
+    from text_generation_server import server
+    from text_generation_server.tracing import setup_tracing
+
     # Setup OpenTelemetry distributed tracing
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
@@ -75,6 +78,9 @@ def download_weights(
         diagnose=False,
     )

+    # Import here after the logger is added to log potential import exceptions
+    from text_generation_server import utils
+
     # Test if files were already download
     try:
         utils.weight_files(model_id, revision, extension)
......
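The cli.py hunks above move the `server`, `utils`, and `tracing` imports from module scope into the command bodies, so they run only after `logger.add(...)` has configured loguru and any import failure can be reported through the logger. A minimal sketch of the pattern; `heavy_backend` is a hypothetical stand-in for `text_generation_server.server`:

```python
# Sketch of the deferred-import pattern used in cli.py; "heavy_backend"
# is a hypothetical module name.
import sys

from loguru import logger


def serve():
    logger.add(sys.stdout, level="INFO", serialize=True)
    try:
        # Imported here, after the logger exists, so a broken install
        # surfaces as a structured log record rather than a bare crash.
        import heavy_backend
    except ImportError:
        logger.opt(exception=True).error("Could not import heavy_backend")
        raise
    heavy_backend.run()
```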
@@ -26,7 +26,7 @@ try:

     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    logger.exception("Could not import Flash Attention enabled models")
+    logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False

 __all__ = [
......
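On this last hunk: `logger.exception(...)` always logs at ERROR severity, whereas `logger.opt(exception=True)` attaches the active exception's traceback to a record of any level, which is what lets the missing-Flash-Attention case be downgraded to a warning. A short sketch of the distinction (the import name is illustrative):

```python
from loguru import logger

try:
    import flash_attn_cuda  # illustrative optional dependency
except ImportError:
    # logger.exception(...) would log this at ERROR with a traceback;
    # opt(exception=True) keeps the traceback but allows WARNING severity.
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False
```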