OpenDAS / text-generation-inference · Commits · bf99afe9

Commit bf99afe9, authored Oct 14, 2022 by Olivier Dehaene

feat: Docker image

parent 39df4d99
Showing 13 changed files with 266 additions and 181 deletions (+266 −181)
.dockerignore                                +1    −0
Dockerfile                                   +59   −0
README.md                                    +0    −2
router/rust-toolchain.toml                   +3    −0
router/src/batcher.rs                        +6    −1
router/src/main.rs                           +2    −2
router/src/server.rs                         +44   −18
run.sh                                       +21   −0
server/bloom_inference/model.py              +5    −2
server/bloom_inference/prepare_weights.py    +111  −53
server/bloom_inference/shard_model.py        +0    −102
server/poetry.lock                           +13   −1
server/pyproject.toml                        +1    −0
.dockerignore    0 → 100644

router/target
\ No newline at end of file
Dockerfile    0 → 100644

FROM rust:1.64 as builder

WORKDIR /usr/src

COPY proto proto
COPY router router

WORKDIR /usr/src/router

RUN cargo install --path .

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

ENV LANG=C.UTF-8 \
    LC_ALL=C.UTF-8 \
    DEBIAN_FRONTEND=noninteractive \
    MODEL_BASE_PATH=/var/azureml-model \
    MODEL_NAME=bigscience/bloom \
    NUM_GPUS=8 \
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
    CUDA_HOME=/usr/local/cuda \
    LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
    CONDA_DEFAULT_ENV=text-generation \
    PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin

SHELL ["/bin/bash", "-c"]

RUN apt-get update && apt-get install -y unzip wget libssl-dev && rm -rf /var/lib/apt/lists/*

RUN cd ~ && \
    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x Miniconda3-latest-Linux-x86_64.sh && \
    bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
    conda create -n text-generation python=3.9 -y

# Install specific version of torch
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir

# Install specific version of transformers
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
    unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
    rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
    cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
    /opt/miniconda/envs/text-generation/bin/python setup.py install

WORKDIR /usr/src

# Install server
COPY server server
RUN cd server && \
    /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir

# Install router
COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference

COPY run.sh .
RUN chmod +x run.sh

CMD ["./run.sh"]
\ No newline at end of file
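
The image builds the Rust router in a rust:1.64 stage, then installs the Python server into a Miniconda environment on top of nvidia/cuda:11.8.0-devel-ubuntu22.04 and copies the bloom-inference binary across. A rough usage sketch, not part of the commit: the image tag, host weight directory, and port mapping below are placeholders; the port matches the 0.0.0.0:3000 bind introduced in router/src/main.rs below, and the mount target matches the MODEL_BASE_PATH default above.

# Hypothetical build/run commands; image name and host paths are illustrative only.
docker build -t bloom-inference .
docker run --gpus all -p 3000:3000 \
    -v /path/to/prepared/weights:/var/azureml-model \
    -e MODEL_NAME=bigscience/bloom \
    -e NUM_GPUS=8 \
    bloom-inference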
README.md

@@ -43,8 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire

 ## TODO:

-- [ ] Improve model download
-- Store "shardable" layers separately and layer by layer
 - [ ] Add batching args to router CLI
 - [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
 - [ ] Add tests
router/rust-toolchain.toml    0 → 100644

[toolchain]
channel = "1.64.0"
components = ["rustfmt", "clippy"]
\ No newline at end of file
router/src/batcher.rs

@@ -83,7 +83,12 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {

             cached_batch = match batch_size {
                 size if size > 16 => {
-                    wrap_future(client.generate_until_finished_with_cache(batches), request_ids, &db).await
+                    wrap_future(
+                        client.generate_until_finished_with_cache(batches),
+                        request_ids,
+                        &db,
+                    )
+                    .await
                 }
                 _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
             };
router/src/main.rs

-use std::net::SocketAddr;
 use bloom_inference_client::ShardedClient;
+use std::net::SocketAddr;
 use std::time::Duration;
 use tokenizers::Tokenizer;

@@ -37,7 +37,7 @@ fn main() -> Result<(), std::io::Error> {

             .expect("Unable to clear cache");
         tracing::info!("Connected");

-        let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
+        let addr = SocketAddr::from(([0, 0, 0, 0], 3000));

         server::run(sharded_client, tokenizer, addr).await;
         Ok(())
router/src/server.rs

-use std::net::SocketAddr;
 use crate::{Batcher, ShardedClient, Validation};
-use axum::{Router, Json};
-use axum::http::StatusCode;
 use axum::extract::Extension;
+use axum::http::StatusCode;
 use axum::routing::post;
+use axum::{Json, Router};
 use serde::Deserialize;
+use std::net::SocketAddr;
 use tokenizers::Tokenizer;
 use tokio::time::Instant;
 use tracing::instrument;

@@ -60,6 +60,31 @@ pub(crate) struct GenerateRequest {

     pub parameters: GenerateParameters,
 }

+#[instrument(skip(state), fields(time, time_per_token))]
+async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
+    let output = state
+        .infer
+        .infer(
+            1,
+            GenerateRequest {
+                inputs: "liveness".to_string(),
+                parameters: GenerateParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    do_sample: false,
+                    max_new_tokens: 1,
+                },
+            },
+        )
+        .await;
+    match output {
+        Ok(_) => Ok(()),
+        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
+    }
+}
+
 #[instrument(skip(state), fields(time, time_per_token))]
 async fn generate(
     state: Extension<ServerState>,

@@ -67,14 +92,16 @@ async fn generate(

 ) -> Result<Json<serde_json::Value>, StatusCode> {
     let start = Instant::now();
     let (input_length, validated_request) = match state
         .validation
         .validate(GenerateRequest {
             inputs: req.inputs.clone(),
             parameters: req.parameters.clone(),
         })
         .await
     {
         Ok(result) => result,
         Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
     };

     let output = state.infer.infer(input_length, validated_request).await;

@@ -102,11 +129,7 @@ struct ServerState {

     infer: Batcher,
 }

-pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
+pub async fn run(
+    client: ShardedClient,
+    tokenizer: Tokenizer,
+    addr: SocketAddr,
+) {
     client
         .clear_cache()
         .await
         .expect("Unable to clear cache");
     tracing::info!("Connected");

@@ -114,13 +137,16 @@ pub async fn run(

     let validation = Validation::new(tokenizer);

-    let shared_state = ServerState { validation, infer };
+    let shared_state = ServerState {
+        validation,
+        infer,
+    };

     let app = Router::new()
-        .route("/generate", post(generate))
-        .layer(Extension(shared_state));
+        .route("/generate", post(generate))
+        .layer(Extension(shared_state.clone()))
+        .route("/health", post(liveness))
+        .layer(Extension(shared_state.clone()));

     axum::Server::bind(&addr)
         .serve(app.into_make_service())
         .await
         .unwrap();
 }
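
This change adds a /health liveness route next to /generate, and together with the 0.0.0.0 bind in router/src/main.rs it makes the router reachable from outside the container. A smoke-test sketch, not part of the commit: it assumes the container's port 3000 is published on localhost, and the JSON body simply mirrors the GenerateRequest and GenerateParameters fields used by the liveness handler above.

# Hypothetical requests against a running instance.
curl -X POST http://localhost:3000/health

curl -X POST http://localhost:3000/generate \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "Hello, I am", "parameters": {"temperature": 1.0, "top_k": 0, "top_p": 1.0, "do_sample": false, "max_new_tokens": 20}}'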
run.sh    0 → 100755

#!/usr/bin/env bash

server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"

$server_cmd &

FILE=/tmp/bloom-inference-0

while :
do
    if test -S "$FILE"; then
        echo "Text Generation Python gRPC server started"
        break
    else
        echo "Waiting for Text Generation Python gRPC server to start"
        sleep 5
    fi
done

sleep 1

exec "bloom-inference"
server/bloom_inference/model.py

@@ -220,12 +220,14 @@ class BLOOM:

     def __init__(self, model_name: str):
         if torch.cuda.is_available():
             self.device = torch.device("cuda")
+            dtype = torch.bfloat16
         else:
             self.device = torch.device("cpu")
+            dtype = torch.float32

         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
         self.model = (
             AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
+            .to(dtype)
         )
         self.num_heads = self.model.base_model.num_heads

@@ -427,7 +429,8 @@ class BLOOMSharded(BLOOM):

             if do_transpose:
                 state_dict[key] = state_dict[key].transpose(1, 0).contiguous()

-        model.load_state_dict(state_dict)
+        model.load_state_dict(state_dict, strict=False)
+        model.tie_weights()
         self.model = model.to(self.device).eval()
         self.num_heads = config.n_head // self.process_group.size()
         torch.distributed.barrier(group=self.process_group)
server/bloom_inference/prepare_weights.py

 import torch
+import os
+import tempfile
+import json
+from typing import BinaryIO
+from joblib import Parallel, delayed
+from functools import partial
 from pathlib import Path
 from tqdm import tqdm

-MODEL_NAME = "bigscience/bloom"
+from huggingface_hub import hf_hub_url
+from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status


 def match_suffix(text, suffix):
-    return text[-len(suffix) :] == suffix
+    return text[-len(suffix):] == suffix


-def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
+def http_get(
+    url: str,
+    temp_file: BinaryIO,
+    *,
+    timeout=10.0,
+    max_retries=0,
+):
+    """
+    Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
+    """
+    r = _request_wrapper(
+        method="GET",
+        url=url,
+        stream=True,
+        timeout=timeout,
+        max_retries=max_retries,
+    )
+    hf_raise_for_status(r)
+    for chunk in r.iter_content(chunk_size=1024):
+        if chunk:  # filter out keep-alive new chunks
+            temp_file.write(chunk)
+
+
+def cache_download_url(url: str, root_dir: Path):
+    filename = root_dir / url.split("/")[-1]
+
+    if not filename.exists():
+        temp_file_manager = partial(
+            tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
+        )
+        with temp_file_manager() as temp_file:
+            http_get(url, temp_file)
+
+        os.replace(temp_file.name, filename)
+    return filename
+
+
+def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
     save_paths = [
-        save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
+        save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
         for tp_rank in range(tp_world_size)
     ]

@@ -20,45 +64,67 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):

         print("Weights are already prepared")
         return

+    cache_path.mkdir(parents=True, exist_ok=True)
+    if model_name == "bigscience/bloom-560m":
+        url = hf_hub_url(model_name, filename="pytorch_model.bin")
+        cache_download_url(url, cache_path)
+    elif model_name == "bigscience/bloom":
+        url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
+        index_path = cache_download_url(url, cache_path)
+        with index_path.open("r") as f:
+            index = json.load(f)
+
+        # Get unique file names
+        weight_files = list(set([filename for filename in index["weight_map"].values()]))
+        urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
+
+        Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls))
+    else:
+        raise ValueError(f"Unknown model name: {model_name}")
+
     shards_state_dicts = [{} for _ in range(tp_world_size)]

-    for weight_path in tqdm(hub_path.glob("*.bin")):
+    for weight_path in tqdm(Path(cache_path).glob("*.bin")):
         state_dict = torch.load(weight_path, map_location="cpu")
         keys = list(state_dict.keys())
         for state_name in keys:
             state = state_dict[state_name]
             if any(
                 match_suffix(state_name, candidate)
                 for candidate in [
                     "self_attention.query_key_value.weight",
                     "self_attention.query_key_value.bias",
                     "mlp.dense_h_to_4h.weight",
                     "mlp.dense_h_to_4h.bias",
                     "word_embeddings.weight",
+                    "lm_head.weight",
                 ]
             ):
                 output_size = state.shape[0]
                 assert output_size % tp_world_size == 0
                 block_size = output_size // tp_world_size
                 sharded_weights = torch.split(state, block_size, dim=0)
                 assert len(sharded_weights) == tp_world_size
                 for tp_rank, shard in enumerate(sharded_weights):
-                    assert shard.shape[0] == block_size
-                    shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
-            elif match_suffix(state_name, "lm_head.weight"):
-                output_size = state.shape[0]
-                assert output_size % tp_world_size == 0
-                block_size = output_size // tp_world_size
-                sharded_weights = torch.split(state, block_size, dim=0)
-                assert len(sharded_weights) == tp_world_size
-                for tp_rank, shard in enumerate(sharded_weights):
-                    shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
+                    if match_suffix(state_name, "lm_head.weight"):
+                        shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
+                    else:
+                        shards_state_dicts[tp_rank][
+                            "transformer." + state_name
+                        ] = shard.detach().clone()
             elif any(
                 match_suffix(state_name, candidate)
                 for candidate in [
                     "self_attention.dense.weight",
                     "mlp.dense_4h_to_h.weight",
+                    "lm_head.weight",
                 ]
             ):
                 input_size = state.shape[1]
                 assert input_size % tp_world_size == 0

@@ -66,40 +132,31 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):

                 sharded_weights = torch.split(state, block_size, dim=1)
                 assert len(sharded_weights) == tp_world_size
                 for tp_rank, shard in enumerate(sharded_weights):
-                    assert shard.shape[1] == block_size
-                    shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
+                    if match_suffix(state_name, "lm_head.weight"):
+                        shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
+                    else:
+                        shards_state_dicts[tp_rank][
+                            "transformer." + state_name
+                        ] = shard.detach().clone()
             elif any(
                 match_suffix(state_name, candidate)
                 for candidate in [
                     "self_attention.dense.bias",
                     "mlp.dense_4h_to_h.bias",
                 ]
             ):
                 shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
                 for tp_rank in range(1, tp_world_size):
                     shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
             else:
                 # We duplicate parameters across tp ranks
                 for tp_rank in range(tp_world_size):
                     shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()

             del state_dict[state_name]  # delete key from state_dict
             del state  # delete tensor
+        del state_dict

     # we save state_dict
     for tp_rank, (save_path, shard_state_dict) in enumerate(
         zip(save_paths, shards_state_dicts)
     ):
         save_paths.append(save_path)
         save_path.parent.mkdir(parents=True, exist_ok=True)

@@ -116,9 +173,10 @@ if __name__ == "__main__":

     parser = ArgumentParser()
-    parser.add_argument("--hub-path", required=True, type=str)
+    parser.add_argument("--model-name", required=True, type=str)
+    parser.add_argument("--cache-path", required=True, type=str)
     parser.add_argument("--save-path", required=True, type=str)
     parser.add_argument("--world-size", required=True, type=int)
     args = parser.parse_args()

-    prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)
+    prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)
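
prepare_weights.py now downloads the checkpoint from the Hugging Face Hub itself (via hf_hub_url and joblib-parallel fetches) instead of reading a local snapshot, so the CLI switches from --hub-path to --model-name plus --cache-path. A usage sketch with placeholder directories; pointing --save-path at the MODEL_BASE_PATH used by the Docker image is an assumption, not something this commit enforces:

# Hypothetical invocation; paths are illustrative only.
python server/bloom_inference/prepare_weights.py \
    --model-name bigscience/bloom \
    --cache-path /data/cache \
    --save-path /var/azureml-model \
    --world-size 8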
server/bloom_inference/shard_model.py    deleted 100644 → 0

from pathlib import Path

import torch
from torch import nn
from transformers import AutoModelForCausalLM


def match_suffix(text, suffix):
    return text[-len(suffix) :] == suffix


def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
    """BLOOM specific sharding mechanism"""
    save_paths = [
        path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]
    if all(save_path.exists() for save_path in save_paths):
        print("Loading already cached values")
        return save_paths

    model: nn.Module = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, local_files_only=True
    )

    shards_state_dicts = [{} for _ in range(tp_world_size)]

    state_dict = model.state_dict()
    keys = list(state_dict.keys())
    for state_name in keys:
        print(state_name)
        state = state_dict[state_name]
        if any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.query_key_value.weight",
                "self_attention.query_key_value.bias",
                "mlp.dense_h_to_4h.weight",
                "mlp.dense_h_to_4h.bias",
                "transformer.word_embeddings.weight",
                "lm_head.weight",
            ]
        ):
            output_size = state.shape[0]
            assert output_size % tp_world_size == 0
            block_size = output_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=0)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[0] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.weight",
                "mlp.dense_4h_to_h.weight",
                "lm_head.weight",
            ]
        ):
            input_size = state.shape[1]
            assert input_size % tp_world_size == 0
            block_size = input_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=1)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[1] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.bias",
                "mlp.dense_4h_to_h.bias",
            ]
        ):
            shards_state_dicts[0][state_name] = state.detach().clone()
            for tp_rank in range(1, tp_world_size):
                shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
        else:
            # We duplicate parameters across tp ranks
            for tp_rank in range(tp_world_size):
                shards_state_dicts[tp_rank][state_name] = state.detach().clone()

        del state_dict[state_name]  # delete key from state_dict
        del state  # delete tensor

    # we save state_dict
    for tp_rank, (save_path, shard_state_dict) in enumerate(
        zip(save_paths, shards_state_dicts)
    ):
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(shard_state_dict, save_path)
        save_paths.append(save_path)

    return save_paths


if __name__ == "__main__":
    model_name = "bigscience/bloom"
    save_path = Path("/data/shards")
    tp_world_size = 8
    dtype = torch.bfloat16

    shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)
server/poetry.lock

@@ -80,6 +80,14 @@ grpcio = ">=1.49.1"

 protobuf = ">=4.21.3,<5.0dev"
 setuptools = "*"

+[[package]]
+name = "joblib"
+version = "1.2.0"
+description = "Lightweight pipelining with Python functions"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
 [[package]]
 name = "numpy"
 version = "1.23.3"

@@ -197,7 +205,7 @@ python-versions = ">=3.7"

 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25"
+content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"

 [metadata.files]
 accelerate = [

@@ -310,6 +318,10 @@ grpcio-tools = [

     {file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"},
     {file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"},
 ]
+joblib = [
+    {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"},
+    {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"},
+]
 numpy = [
     {file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"},
     {file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"},
server/pyproject.toml

@@ -12,6 +12,7 @@ torch = "^1.12.1"

 typer = "^0.6.1"
 grpcio-reflection = "^1.49.1"
 accelerate = "^0.12.0"
+joblib = "^1.2.0"

 [tool.poetry.group.dev.dependencies]
 grpcio-tools = "^1.49.1"
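
joblib is added for the Parallel/delayed weight downloads in prepare_weights.py; the new joblib entry and updated content-hash in server/poetry.lock above follow from this line. A sketch of refreshing the environment after such a change, assuming Poetry is used here as in the rest of the server setup:

# Hypothetical commands, run from the server/ directory.
cd server
poetry lock
poetry install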