Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
8a8f4341
Unverified
Commit
8a8f4341
authored
May 12, 2023
by
OlivierDehaene
Committed by
GitHub
May 12, 2023
Browse files
chore(docker): use nvidia base image (#318)
parent
76a48cd3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
19 deletions
+21
-19
Dockerfile
Dockerfile
+1
-12
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
...erver/models/custom_modeling/flash_santacoder_modeling.py
+19
-7
server/text_generation_server/models/flash_santacoder.py
server/text_generation_server/models/flash_santacoder.py
+1
-0
No files found.
Dockerfile
View file @
8a8f4341
...
...
@@ -108,7 +108,7 @@ COPY server/Makefile-transformers Makefile
RUN
BUILD_EXTENSIONS
=
"True"
make build-transformers
# Text Generation Inference base image
FROM
debian:bullseye-slim
as base
FROM
nvidia/cuda:11.8.0-base-ubuntu22.04
as base
# Conda env
ENV
PATH=/opt/conda/bin:$PATH \
...
...
@@ -122,17 +122,6 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
NUM_SHARD=1 \
PORT=80
# NVIDIA env vars
ENV
NVIDIA_VISIBLE_DEVICES all
ENV
NVIDIA_DRIVER_CAPABILITIES compute,utility
# Required for nvidia-docker v1
RUN
/bin/bash
-c
echo
"/usr/local/nvidia/lib"
>>
/etc/ld.so.conf.d/nvidia.conf
&&
\
echo
"/usr/local/nvidia/lib64"
>>
/etc/ld.so.conf.d/nvidia.conf
ENV
LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV
PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
LABEL
com.nvidia.volumes.needed="nvidia_driver"
WORKDIR
/usr/src
RUN
apt-get update
&&
DEBIAN_FRONTEND
=
noninteractive apt-get
install
-y
--no-install-recommends
\
...
...
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
View file @
8a8f4341
...
...
@@ -585,13 +585,25 @@ class FlashSantacoderForCausalLM(nn.Module):
if
self
.
transformer
.
tp_embeddings
:
# Logits are sharded, so we need to gather them
world_logits
=
[
torch
.
empty_like
(
logits
)
for
_
in
range
(
self
.
transformer
.
tp_world_size
)
]
torch
.
distributed
.
all_gather
(
world_logits
,
logits
,
group
=
self
.
transformer
.
process_group
)
world_logits
=
torch
.
cat
(
world_logits
,
dim
=
1
)
if
logits
.
shape
[
0
]
==
1
:
# Fast path when batch size is 1
world_logits
=
logits
.
new_empty
(
(
logits
.
shape
[
1
]
*
self
.
transformer
.
tp_world_size
)
)
torch
.
distributed
.
all_gather_into_tensor
(
world_logits
,
logits
.
view
(
-
1
),
group
=
self
.
transformer
.
process_group
)
world_logits
=
world_logits
.
view
(
1
,
-
1
)
else
:
# We cannot use all_gather_into_tensor as it only support concatenating on the first dim
world_logits
=
[
torch
.
empty_like
(
logits
)
for
_
in
range
(
self
.
transformer
.
tp_world_size
)
]
torch
.
distributed
.
all_gather
(
world_logits
,
logits
,
group
=
self
.
transformer
.
process_group
)
world_logits
=
torch
.
cat
(
world_logits
,
dim
=
1
)
return
world_logits
,
present
...
...
server/text_generation_server/models/flash_santacoder.py
View file @
8a8f4341
...
...
@@ -217,6 +217,7 @@ class FlashSantacoderSharded(FlashSantacoder):
device
=
device
,
rank
=
rank
,
world_size
=
world_size
,
decode_buffer
=
1
,
)
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment