Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
7a1ba585
Unverified
Commit
7a1ba585
authored
Apr 17, 2023
by
OlivierDehaene
Committed by
GitHub
Apr 17, 2023
Browse files
fix(docker): fix docker image dependencies (#187)
parent
379c5c4d
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
875 additions
and
56 deletions
+875
-56
.github/workflows/build.yaml
.github/workflows/build.yaml
+1
-1
Dockerfile
Dockerfile
+102
-25
launcher/Cargo.toml
launcher/Cargo.toml
+1
-1
launcher/src/main.rs
launcher/src/main.rs
+48
-12
server/Makefile
server/Makefile
+1
-0
server/poetry.lock
server/poetry.lock
+13
-13
server/pyproject.toml
server/pyproject.toml
+1
-1
server/requirements.txt
server/requirements.txt
+699
-0
server/text_generation_server/cli.py
server/text_generation_server/cli.py
+8
-2
server/text_generation_server/models/__init__.py
server/text_generation_server/models/__init__.py
+1
-1
No files found.
.github/workflows/build.yaml
View file @
7a1ba585
...
@@ -79,8 +79,8 @@ jobs:
...
@@ -79,8 +79,8 @@ jobs:
flavor
:
|
flavor
:
|
latest=auto
latest=auto
images
:
|
images
:
|
ghcr.io/huggingface/text-generation-inference
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags
:
|
tags
:
|
type=semver,pattern={{version}}
type=semver,pattern={{version}}
...
...
Dockerfile
View file @
7a1ba585
# Rust builder
FROM
lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
FROM
lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
WORKDIR
/usr/src
WORKDIR
/usr/src
...
@@ -27,51 +28,127 @@ COPY router router
...
@@ -27,51 +28,127 @@ COPY router router
COPY
launcher launcher
COPY
launcher launcher
RUN
cargo build
--release
RUN
cargo build
--release
FROM
nvidia/cuda:11.8.0-devel-ubuntu22.04 as base
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM
debian:bullseye-slim as pytorch-install
ARG
PYTORCH_VERSION=2.0.0
ARG
PYTHON_VERSION=3.9
ARG
CUDA_VERSION=11.8
ARG
MAMBA_VERSION=23.1.0-1
ARG
MAMBA_VERSION=23.1.0-1
ARG
CUDA_CHANNEL=nvidia
ARG
INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG
TARGETPLATFORM
ENV
PATH /opt/conda/bin:$PATH
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--no-install-recommends
\
build-essential
\
ca-certificates
\
ccache
\
curl
\
git
&&
\
rm
-rf
/var/lib/apt/lists/
*
# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
RUN case
${
TARGETPLATFORM
}
in
\
"linux/arm64"
)
MAMBA_ARCH
=
aarch64
;;
\
*
)
MAMBA_ARCH
=
x86_64
;;
\
esac
&&
\
curl
-fsSL
-v
-o
~/mambaforge.sh
-O
"https://github.com/conda-forge/miniforge/releases/download/
${
MAMBA_VERSION
}
/Mambaforge-
${
MAMBA_VERSION
}
-Linux-
${
MAMBA_ARCH
}
.sh"
RUN
chmod
+x ~/mambaforge.sh
&&
\
bash ~/mambaforge.sh
-b
-p
/opt/conda
&&
\
rm
~/mambaforge.sh
# Install pytorch
# On arm64 we exit with an error code
RUN case
${
TARGETPLATFORM
}
in
\
"linux/arm64"
)
exit
1
;;
\
*
)
/opt/conda/bin/conda update
-y
conda
&&
\
/opt/conda/bin/conda
install
-c
"
${
INSTALL_CHANNEL
}
"
-c
"
${
CUDA_CHANNEL
}
"
-y
"python=
${
PYTHON_VERSION
}
"
pytorch
==
$PYTORCH_VERSION
"pytorch-cuda=
$(
echo
$CUDA_VERSION
|
cut
-d
'.'
-f
1-2
)
"
;;
\
esac
&&
\
/opt/conda/bin/conda clean
-ya
# CUDA kernels builder image
FROM
pytorch-install as kernel-builder
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--no-install-recommends
\
ninja-build
\
&&
rm
-rf
/var/lib/apt/lists/
*
RUN
/opt/conda/bin/conda
install
-c
"nvidia/label/cuda-11.8.0"
cuda
==
11.8
&&
\
/opt/conda/bin/conda clean
-ya
# Build Flash Attention CUDA kernels
FROM
kernel-builder as flash-att-builder
WORKDIR
/usr/src
COPY
server/Makefile-flash-att Makefile
# Build specific version of flash attention
RUN
make build-flash-attention
# Build Transformers CUDA kernels
FROM
kernel-builder as transformers-builder
ENV
LANG=C.UTF-8 \
WORKDIR
/usr/src
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
COPY
server/Makefile-transformers Makefile
HUGGINGFACE_HUB_CACHE=/data \
# Build specific version of transformers
RUN
BUILD_EXTENSIONS
=
"True"
make build-transformers
# Text Generation Inference base image
FROM
debian:bullseye-slim as base
# Conda env
ENV
PATH=/opt/conda/bin:$PATH \
CONDA_PREFIX=/opt/conda
# Text Generation Inference base env
ENV
HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
HF_HUB_ENABLE_HF_TRANSFER=1 \
MODEL_ID=bigscience/bloom-560m \
MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \
QUANTIZE=false \
NUM_SHARD=1 \
NUM_SHARD=1 \
PORT=80 \
PORT=80
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/conda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
PATH=$PATH:/opt/conda/bin:/usr/local/cuda/bin
RUN
apt-get update
&&
apt-get
install
-y
git curl libssl-dev ninja-build
&&
rm
-rf
/var/lib/apt/lists/
*
LABEL
com.nvidia.volumes.needed="nvidia_driver"
RUN
cd
~
&&
\
curl
-fsSL
-v
-o
~/mambaforge.sh
-O
"https://github.com/conda-forge/miniforge/releases/download/
${
MAMBA_VERSION
}
/Mambaforge-
${
MAMBA_VERSION
}
-Linux-x86_64.sh"
\
chmod
+x ~/mambaforge.sh
&&
\
bash ~/mambaforge.sh
-b
-p
/opt/conda
&&
\
rm
~/mambaforge.sh
WORKDIR
/usr/src
WORKDIR
/usr/src
# Install torch
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--no-install-recommends
\
RUN
pip
install
torch
==
2.0.0
--extra-index-url
https://download.pytorch.org/whl/cu118
--no-cache-dir
libssl-dev
\
ca-certificates
\
make
\
&&
rm
-rf
/var/lib/apt/lists/
*
# Install specific version of flash attention
# Copy conda with PyTorch installed
COPY
server/Makefile-flash-att server/Makefile
COPY
--from=pytorch-install /opt/conda /opt/conda
RUN
cd
server
&&
make install-flash-attention
# Install specific version of transformers
# Copy build artifacts from flash attention builder
COPY
server/Makefile-transformers server/Makefile
COPY
--from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
RUN
cd
server
&&
BUILD_EXTENSIONS
=
"True"
make install-transformers
COPY
--from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY
--from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY
server/Makefile server/Makefile
# Copy build artifacts from transformers builder
COPY
--from=transformers-builder /usr/src/transformers /usr/src/transformers
COPY
--from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
# Install transformers dependencies
RUN
cd
/usr/src/transformers
&&
pip
install
-e
.
--no-cache-dir
&&
pip
install
einops
--no-cache-dir
# Install server
# Install server
COPY
proto proto
COPY
proto proto
COPY
server server
COPY
server server
COPY
server/Makefile server/Makefile
RUN
cd
server
&&
\
RUN
cd
server
&&
\
make gen-server
&&
\
make gen-server
&&
\
pip
install
-r
requirements.txt
&&
\
pip
install
".[bnb]"
--no-cache-dir
pip
install
".[bnb]"
--no-cache-dir
# Install router
# Install router
...
...
launcher/Cargo.toml
View file @
7a1ba585
...
@@ -8,6 +8,7 @@ description = "Text Generation Launcher"
...
@@ -8,6 +8,7 @@ description = "Text Generation Launcher"
[dependencies]
[dependencies]
clap
=
{
version
=
"4.1.4"
,
features
=
[
"derive"
,
"env"
]
}
clap
=
{
version
=
"4.1.4"
,
features
=
[
"derive"
,
"env"
]
}
ctrlc
=
{
version
=
"3.2.5"
,
features
=
["termination"]
}
ctrlc
=
{
version
=
"3.2.5"
,
features
=
["termination"]
}
serde
=
{
version
=
"1.0.152"
,
features
=
["derive"]
}
serde_json
=
"1.0.93"
serde_json
=
"1.0.93"
subprocess
=
"0.2.9"
subprocess
=
"0.2.9"
tracing
=
"0.1.37"
tracing
=
"0.1.37"
...
@@ -16,4 +17,3 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
...
@@ -16,4 +17,3 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
[dev-dependencies]
float_eq
=
"1.0.1"
float_eq
=
"1.0.1"
reqwest
=
{
version
=
"0.11.14"
,
features
=
[
"blocking"
,
"json"
]
}
reqwest
=
{
version
=
"0.11.14"
,
features
=
[
"blocking"
,
"json"
]
}
serde
=
{
version
=
"1.0.152"
,
features
=
["derive"]
}
launcher/src/main.rs
View file @
7a1ba585
use
clap
::
Parser
;
use
clap
::
Parser
;
use
serde
_json
::
Valu
e
;
use
serde
::
Deserializ
e
;
use
std
::
env
;
use
std
::
env
;
use
std
::
ffi
::
OsString
;
use
std
::
ffi
::
OsString
;
use
std
::
io
::{
BufRead
,
BufReader
,
Read
};
use
std
::
io
::{
BufRead
,
BufReader
,
Read
};
...
@@ -244,11 +244,8 @@ fn main() -> ExitCode {
...
@@ -244,11 +244,8 @@ fn main() -> ExitCode {
let
_
span
=
tracing
::
span!
(
tracing
::
Level
::
INFO
,
"download"
)
.entered
();
let
_
span
=
tracing
::
span!
(
tracing
::
Level
::
INFO
,
"download"
)
.entered
();
for
line
in
stdout
.lines
()
{
for
line
in
stdout
.lines
()
{
// Parse loguru logs
// Parse loguru logs
if
let
Ok
(
value
)
=
serde_json
::
from_str
::
<
Value
>
(
&
line
.unwrap
())
{
if
let
Ok
(
log
)
=
serde_json
::
from_str
::
<
PythonLogMessage
>
(
&
line
.unwrap
())
{
if
let
Some
(
text
)
=
value
.get
(
"text"
)
{
log
.trace
();
// Format escaped newlines
tracing
::
info!
(
"{}"
,
text
.to_string
()
.replace
(
"
\\
n"
,
""
));
}
}
}
}
}
});
});
...
@@ -525,7 +522,7 @@ fn shard_manager(
...
@@ -525,7 +522,7 @@ fn shard_manager(
"--uds-path"
.to_string
(),
"--uds-path"
.to_string
(),
uds_path
,
uds_path
,
"--logger-level"
.to_string
(),
"--logger-level"
.to_string
(),
"
ERROR
"
.to_string
(),
"
INFO
"
.to_string
(),
"--json-output"
.to_string
(),
"--json-output"
.to_string
(),
];
];
...
@@ -643,11 +640,8 @@ fn shard_manager(
...
@@ -643,11 +640,8 @@ fn shard_manager(
let
_
span
=
tracing
::
span!
(
tracing
::
Level
::
INFO
,
"shard-manager"
,
rank
=
rank
)
.entered
();
let
_
span
=
tracing
::
span!
(
tracing
::
Level
::
INFO
,
"shard-manager"
,
rank
=
rank
)
.entered
();
for
line
in
stdout
.lines
()
{
for
line
in
stdout
.lines
()
{
// Parse loguru logs
// Parse loguru logs
if
let
Ok
(
value
)
=
serde_json
::
from_str
::
<
Value
>
(
&
line
.unwrap
())
{
if
let
Ok
(
log
)
=
serde_json
::
from_str
::
<
PythonLogMessage
>
(
&
line
.unwrap
())
{
if
let
Some
(
text
)
=
value
.get
(
"text"
)
{
log
.trace
();
// Format escaped newlines
tracing
::
error!
(
"{}"
,
text
.to_string
()
.replace
(
"
\\
n"
,
"
\n
"
));
}
}
}
}
}
});
});
...
@@ -708,3 +702,45 @@ fn num_cuda_devices() -> Option<usize> {
...
@@ -708,3 +702,45 @@ fn num_cuda_devices() -> Option<usize> {
}
}
None
None
}
}
#[derive(Deserialize)]
#[serde(rename_all
=
"UPPERCASE"
)]
enum
PythonLogLevelEnum
{
Trace
,
Debug
,
Info
,
Success
,
Warning
,
Error
,
Critical
,
}
#[derive(Deserialize)]
struct
PythonLogLevel
{
name
:
PythonLogLevelEnum
,
}
#[derive(Deserialize)]
struct
PythonLogRecord
{
level
:
PythonLogLevel
,
}
#[derive(Deserialize)]
struct
PythonLogMessage
{
text
:
String
,
record
:
PythonLogRecord
,
}
impl
PythonLogMessage
{
fn
trace
(
&
self
)
{
match
self
.record.level.name
{
PythonLogLevelEnum
::
Trace
=>
tracing
::
trace!
(
"{}"
,
self
.text
),
PythonLogLevelEnum
::
Debug
=>
tracing
::
debug!
(
"{}"
,
self
.text
),
PythonLogLevelEnum
::
Info
=>
tracing
::
info!
(
"{}"
,
self
.text
),
PythonLogLevelEnum
::
Success
=>
tracing
::
info!
(
"{}"
,
self
.text
),
PythonLogLevelEnum
::
Warning
=>
tracing
::
warn!
(
"{}"
,
self
.text
),
PythonLogLevelEnum
::
Error
=>
tracing
::
error!
(
"{}"
,
self
.text
),
PythonLogLevelEnum
::
Critical
=>
tracing
::
error!
(
"{}"
,
self
.text
),
}
}
}
server/Makefile
View file @
7a1ba585
...
@@ -16,6 +16,7 @@ install-torch:
...
@@ -16,6 +16,7 @@ install-torch:
install
:
gen-server install-torch install-transformers
install
:
gen-server install-torch install-transformers
pip
install
pip
--upgrade
pip
install
pip
--upgrade
pip
install
-r
requirements.txt
pip
install
-e
.
--no-cache-dir
pip
install
-e
.
--no-cache-dir
run-dev
:
run-dev
:
...
...
server/poetry.lock
View file @
7a1ba585
...
@@ -33,7 +33,7 @@ python-versions = ">=3.7,<4.0"
...
@@ -33,7 +33,7 @@ python-versions = ">=3.7,<4.0"
[[package]]
[[package]]
name = "bitsandbytes"
name = "bitsandbytes"
version = "0.3
5.4
"
version = "0.3
8.1
"
description = "8-bit optimizers and matrix multiplication routines."
description = "8-bit optimizers and matrix multiplication routines."
category = "main"
category = "main"
optional = false
optional = false
...
@@ -138,17 +138,17 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]
...
@@ -138,17 +138,17 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]
[[package]]
[[package]]
name = "grpc-interceptor"
name = "grpc-interceptor"
version = "0.15.
0
"
version = "0.15.
1
"
description = "Simplifies gRPC interceptors"
description = "Simplifies gRPC interceptors"
category = "main"
category = "main"
optional = false
optional = false
python-versions = ">=3.
6.1
,<4.0
.0
"
python-versions = ">=3.
7
,<4.0"
[package.dependencies]
[package.dependencies]
grpcio = ">=1.
32.0
,<2.0.0"
grpcio = ">=1.
49.1
,<2.0.0"
[package.extras]
[package.extras]
testing = ["protobuf (>=
3.6.0
)"]
testing = ["protobuf (>=
4.21.9
)"]
[[package]]
[[package]]
name = "grpcio"
name = "grpcio"
...
@@ -597,7 +597,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
...
@@ -597,7 +597,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]]
[[package]]
name = "pytest"
name = "pytest"
version = "7.3.
0
"
version = "7.3.
1
"
description = "pytest: simple powerful testing with Python"
description = "pytest: simple powerful testing with Python"
category = "dev"
category = "dev"
optional = false
optional = false
...
@@ -833,7 +833,7 @@ bnb = ["bitsandbytes"]
...
@@ -833,7 +833,7 @@ bnb = ["bitsandbytes"]
[metadata]
[metadata]
lock-version = "1.1"
lock-version = "1.1"
python-versions = "^3.9"
python-versions = "^3.9"
content-hash = "
6141d488429e0ab579028036e8e4cbc54f583b48214cb4a6be066bb7ce5154db
"
content-hash = "
e05491a03938b79a71b498f2759169f5a41181084158fde5993e7dcb25292cb0
"
[metadata.files]
[metadata.files]
accelerate = [
accelerate = [
...
@@ -845,8 +845,8 @@ backoff = [
...
@@ -845,8 +845,8 @@ backoff = [
{file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
{file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
]
]
bitsandbytes = [
bitsandbytes = [
{file = "bitsandbytes-0.3
5.4
-py3-none-any.whl", hash = "sha256:
201f168538ccfbd7594568a2f86c149cec8352782301076a15a783695ecec7fb
"},
{file = "bitsandbytes-0.3
8.1
-py3-none-any.whl", hash = "sha256:
5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a
"},
{file = "bitsandbytes-0.3
5.4
.tar.gz", hash = "sha256:b
23db6b91cd73cb14faf9841a66bffa5c1722f9b8b57039ef2fb461ac22dd2a6
"},
{file = "bitsandbytes-0.3
8.1
.tar.gz", hash = "sha256:b
a95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce
"},
]
]
certifi = [
certifi = [
{file = "certifi-2022.12.7-py3-none-any.whl", hash = "sha256:4ad3232f5e926d6718ec31cfc1fcadfde020920e278684144551c91769c7bc18"},
{file = "certifi-2022.12.7-py3-none-any.whl", hash = "sha256:4ad3232f5e926d6718ec31cfc1fcadfde020920e278684144551c91769c7bc18"},
...
@@ -973,8 +973,8 @@ googleapis-common-protos = [
...
@@ -973,8 +973,8 @@ googleapis-common-protos = [
{file = "googleapis_common_protos-1.59.0-py2.py3-none-any.whl", hash = "sha256:b287dc48449d1d41af0c69f4ea26242b5ae4c3d7249a38b0984c86a4caffff1f"},
{file = "googleapis_common_protos-1.59.0-py2.py3-none-any.whl", hash = "sha256:b287dc48449d1d41af0c69f4ea26242b5ae4c3d7249a38b0984c86a4caffff1f"},
]
]
grpc-interceptor = [
grpc-interceptor = [
{file = "grpc-interceptor-0.15.
0
.tar.gz", hash = "sha256:
5c1aa9680b1d7e12259960c38057b121826860b05ebbc1001c74343b7ad1455e
"},
{file = "grpc-interceptor-0.15.
1
.tar.gz", hash = "sha256:
3efadbc9aead272ac7a360c75c4bd96233094c9a5192dbb51c6156246bd64ba0
"},
{file = "grpc_interceptor-0.15.
0
-py3-none-any.whl", hash = "sha256:
63e390162e64df96c39c40508eb697def76a7cafac32a7eaf9272093eec1109e
"},
{file = "grpc_interceptor-0.15.
1
-py3-none-any.whl", hash = "sha256:
1cc52c34b0d7ff34512fb7780742ecda37bf3caa18ecc5f33f09b4f74e96b276
"},
]
]
grpcio = [
grpcio = [
{file = "grpcio-1.53.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:752d2949b40e12e6ad3ed8cc552a65b54d226504f6b1fb67cab2ccee502cc06f"},
{file = "grpcio-1.53.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:752d2949b40e12e6ad3ed8cc552a65b54d226504f6b1fb67cab2ccee502cc06f"},
...
@@ -1329,8 +1329,8 @@ psutil = [
...
@@ -1329,8 +1329,8 @@ psutil = [
{file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"},
{file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"},
]
]
pytest = [
pytest = [
{file = "pytest-7.3.
0
-py3-none-any.whl", hash = "sha256:
933051fa1bfbd38a21e73c3960cebdad4cf59483ddba7696c48509727e17f201
"},
{file = "pytest-7.3.
1
-py3-none-any.whl", hash = "sha256:
3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362
"},
{file = "pytest-7.3.
0
.tar.gz", hash = "sha256:
58ecc27ebf0ea643ebfdf7fb1249335da761a00c9f955bcd922349bcb68ee57d
"},
{file = "pytest-7.3.
1
.tar.gz", hash = "sha256:
434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3
"},
]
]
PyYAML = [
PyYAML = [
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
...
...
server/pyproject.toml
View file @
7a1ba585
...
@@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
...
@@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
grpc-interceptor
=
"^0.15.0"
grpc-interceptor
=
"^0.15.0"
typer
=
"^0.6.1"
typer
=
"^0.6.1"
accelerate
=
"^0.15.0"
accelerate
=
"^0.15.0"
bitsandbytes
=
"^0.3
5
.1"
bitsandbytes
=
"^0.3
8
.1"
safetensors
=
"^0.2.4"
safetensors
=
"^0.2.4"
loguru
=
"^0.6.0"
loguru
=
"^0.6.0"
opentelemetry-api
=
"^1.15.0"
opentelemetry-api
=
"^1.15.0"
...
...
server/requirements.txt
0 → 100644
View file @
7a1ba585
This diff is collapsed.
Click to expand it.
server/text_generation_server/cli.py
View file @
7a1ba585
...
@@ -6,8 +6,6 @@ from pathlib import Path
...
@@ -6,8 +6,6 @@ from pathlib import Path
from
loguru
import
logger
from
loguru
import
logger
from
typing
import
Optional
from
typing
import
Optional
from
text_generation_server
import
server
,
utils
from
text_generation_server.tracing
import
setup_tracing
app
=
typer
.
Typer
()
app
=
typer
.
Typer
()
...
@@ -48,6 +46,11 @@ def serve(
...
@@ -48,6 +46,11 @@ def serve(
backtrace
=
True
,
backtrace
=
True
,
diagnose
=
False
,
diagnose
=
False
,
)
)
# Import here after the logger is added to log potential import exceptions
from
text_generation_server
import
server
from
text_generation_server.tracing
import
setup_tracing
# Setup OpenTelemetry distributed tracing
# Setup OpenTelemetry distributed tracing
if
otlp_endpoint
is
not
None
:
if
otlp_endpoint
is
not
None
:
setup_tracing
(
shard
=
os
.
getenv
(
"RANK"
,
0
),
otlp_endpoint
=
otlp_endpoint
)
setup_tracing
(
shard
=
os
.
getenv
(
"RANK"
,
0
),
otlp_endpoint
=
otlp_endpoint
)
...
@@ -75,6 +78,9 @@ def download_weights(
...
@@ -75,6 +78,9 @@ def download_weights(
diagnose
=
False
,
diagnose
=
False
,
)
)
# Import here after the logger is added to log potential import exceptions
from
text_generation_server
import
utils
# Test if files were already download
# Test if files were already download
try
:
try
:
utils
.
weight_files
(
model_id
,
revision
,
extension
)
utils
.
weight_files
(
model_id
,
revision
,
extension
)
...
...
server/text_generation_server/models/__init__.py
View file @
7a1ba585
...
@@ -26,7 +26,7 @@ try:
...
@@ -26,7 +26,7 @@ try:
FLASH_ATTENTION
=
torch
.
cuda
.
is_available
()
FLASH_ATTENTION
=
torch
.
cuda
.
is_available
()
except
ImportError
:
except
ImportError
:
logger
.
exception
(
"Could not import Flash Attention enabled models"
)
logger
.
opt
(
exception
=
True
).
warning
(
"Could not import Flash Attention enabled models"
)
FLASH_ATTENTION
=
False
FLASH_ATTENTION
=
False
__all__
=
[
__all__
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment