OpenDAS / text-generation-inference / Commits / 295831a4

Commit 295831a4 authored Oct 08, 2022 by Olivier Dehaene

Init
Changes: 43

Showing 20 changed files with 1247 additions and 0 deletions (+1247, -0)
server/bloom_inference/__pycache__/cache.cpython-39.pyc                 +0  -0
server/bloom_inference/__pycache__/model.cpython-39.pyc                 +0  -0
server/bloom_inference/__pycache__/server.cpython-39.pyc                +0  -0
server/bloom_inference/__pycache__/shard_model.cpython-39.pyc           +0  -0
server/bloom_inference/__pycache__/utils.cpython-39.pyc                 +0  -0
server/bloom_inference/cache.py                                         +48  -0
server/bloom_inference/main.py                                          +30  -0
server/bloom_inference/model.py                                         +428  -0
server/bloom_inference/pb/__init__.py                                   +0  -0
server/bloom_inference/pb/__init__.py-e                                 +0  -0
server/bloom_inference/pb/__pycache__/__init__.cpython-39.pyc           +0  -0
server/bloom_inference/pb/__pycache__/generate_pb2.cpython-39.pyc       +0  -0
server/bloom_inference/pb/__pycache__/generate_pb2_grpc.cpython-39.pyc  +0  -0
server/bloom_inference/pb/generate_pb2.py                               +43  -0
server/bloom_inference/pb/generate_pb2.py-e                             +43  -0
server/bloom_inference/pb/generate_pb2_grpc.py                          +169  -0
server/bloom_inference/pb/generate_pb2_grpc.py-e                        +169  -0
server/bloom_inference/prepare_weights.py                               +124  -0
server/bloom_inference/server.py                                        +91  -0
server/bloom_inference/shard_model.py                                   +102  -0
server/bloom_inference/__pycache__/cache.cpython-39.pyc  0 → 100644
File added

server/bloom_inference/__pycache__/model.cpython-39.pyc  0 → 100644
File added

server/bloom_inference/__pycache__/server.cpython-39.pyc  0 → 100644
File added

server/bloom_inference/__pycache__/shard_model.cpython-39.pyc  0 → 100644
File added

server/bloom_inference/__pycache__/utils.cpython-39.pyc  0 → 100644
File added
server/bloom_inference/cache.py  0 → 100644

import torch

from dataclasses import dataclass
from typing import Dict, Optional, List

from bloom_inference.pb import generate_pb2
from bloom_inference.utils import NextTokenChooser, StoppingCriteria


@dataclass
class CacheEntry:
    batch_id: int
    request_ids: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    def __len__(self):
        return len(self.request_ids)

    def to_pb(self):
        return generate_pb2.CacheEntry(
            id=self.batch_id,
            request_ids=self.request_ids,
            sequence_length=max(len(entry) for entry in self.all_input_ids),
        )


class Cache:
    def __init__(self):
        self.cache: Dict[str, CacheEntry] = {}

    def pop(self, batch_id: str) -> Optional[CacheEntry]:
        return self.cache.pop(batch_id, None)

    def set(self, entry: CacheEntry):
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: str):
        del self.cache[batch_id]

    def clear(self):
        self.cache.clear()

    def __len__(self):
        return len(self.cache.keys())

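As an illustration only (not part of this commit), the cache is meant to be keyed by batch id so that a later GenerateWithCache call can resume an in-flight batch; a minimal sketch, with placeholder tensors and empty chooser/criteria lists standing in for the real NextTokenChooser and StoppingCriteria objects from bloom_inference.utils:

import torch

from bloom_inference.cache import Cache, CacheEntry

cache = Cache()

entry = CacheEntry(
    batch_id=0,
    request_ids=[0, 1],
    input_ids={"input_ids": torch.zeros((2, 1), dtype=torch.long)},
    all_input_ids=[torch.zeros((4, 1), dtype=torch.long) for _ in range(2)],
    next_token_choosers=[],  # placeholder: normally one NextTokenChooser per request
    stopping_criterias=[],   # placeholder: normally one StoppingCriteria per request
)

cache.set(entry)        # store the state of the in-flight batch under its batch_id
assert len(cache) == 1
resumed = cache.pop(0)  # returns the entry, or None if the id is unknown
assert resumed is entry and len(cache) == 0
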
server/bloom_inference/main.py  0 → 100644

import typer

from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional

from bloom_inference.server import serve


def main(
    model_name: str,
    num_gpus: int = 1,
    shard_directory: Optional[Path] = None,
):
    if num_gpus == 1:
        serve(model_name, False, shard_directory)

    else:
        config = LaunchConfig(
            min_nodes=1,
            max_nodes=1,
            nproc_per_node=num_gpus,
            rdzv_backend="c10d",
            max_restarts=0,
        )
        launch_agent(config, serve, [model_name, True, shard_directory])


if __name__ == "__main__":
    typer.run(main)

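typer.run(main) turns the function signature into a command line, so the launcher can be started from a shell or called directly from Python; a small sketch, assuming the package is installed as bloom_inference (the model name and paths below are examples only):

from pathlib import Path

from bloom_inference.main import main

# Shell equivalent (typer derives the flags from the signature):
#   python -m bloom_inference.main bigscience/bloom-560m --num-gpus 2 --shard-directory /tmp/models
# With num_gpus > 1, launch_agent spawns one `serve` process per GPU over the c10d rendezvous.
main("bigscience/bloom-560m", num_gpus=2, shard_directory=Path("/tmp/models"))
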
server/bloom_inference/model.py  0 → 100644

import torch
import torch.distributed

from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights

from bloom_inference.cache import CacheEntry
from bloom_inference.pb import generate_pb2
from bloom_inference.shard_model import shard_model, match_suffix
from bloom_inference.utils import (
    StoppingCriteria,
    NextTokenChooser,
    initialize_torch_distributed,
    set_default_dtype,
)

torch.manual_seed(0)


@dataclass
class Batch:
    batch_id: int
    request_ids: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    @classmethod
    def from_batch_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "Batch":
        request_ids = []
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

        # Parse batch
        for r in pb.requests:
            request_ids.append(r.id)
            inputs.append(r.inputs)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

        input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
        all_input_ids = input_ids["input_ids"].unsqueeze(-1)

        return cls(
            pb.id,
            request_ids,
            input_ids,
            all_input_ids,
            next_token_choosers,
            stopping_criterias,
        )

    @classmethod
    def from_cache_entry(cls, cache_entry: CacheEntry) -> "Batch":
        return cls(
            cache_entry.batch_id,
            cache_entry.request_ids,
            cache_entry.input_ids,
            cache_entry.all_input_ids,
            cache_entry.next_token_choosers,
            cache_entry.stopping_criterias,
        )

    @classmethod
    def from_batch_cached_pb(cls, pb: generate_pb2.BatchCached, cache) -> "Batch":
        if len(pb.batch_cached_ids) == 1:
            cache_entry = cache.pop(pb.batch_cached_ids[0])
            if cache_entry is None:
                raise ValueError(f"Batch ID {pb.batch_id} not found in cache")
            return cls.from_cache_entry(cache_entry)

        total_batch_size = pb.total_batch_size
        max_sequence_length = pb.max_sequence_length

        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        request_ids = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        start_index = 0
        for i, batch_id in enumerate(pb.batch_cached_ids):
            cache_entry = cache.pop(batch_id)
            if cache_entry is None:
                raise ValueError(f"Batch ID {batch_id} not found in cache")

            request_ids.extend(cache_entry.request_ids)
            all_input_ids.extend(cache_entry.all_input_ids)
            next_token_choosers.extend(cache_entry.next_token_choosers)
            stopping_criterias.extend(cache_entry.stopping_criterias)

            batch_size = len(cache_entry.request_ids)
            end_index = start_index + batch_size
            sequence_length = max(len(entry) for entry in cache_entry.all_input_ids)

            if input_ids["input_ids"] is None:
                input_ids["input_ids"] = torch.empty(
                    (total_batch_size, 1),
                    dtype=cache_entry.input_ids["input_ids"].dtype,
                    device=cache_entry.input_ids["input_ids"].device,
                )

            input_ids["input_ids"][start_index:end_index] = cache_entry.input_ids[
                "input_ids"
            ]

            if input_ids["attention_mask"] is None:
                input_ids["attention_mask"] = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=cache_entry.input_ids["attention_mask"].dtype,
                    device=cache_entry.input_ids["attention_mask"].device,
                )

            input_ids["attention_mask"][
                start_index:end_index, -sequence_length:
            ] = cache_entry.input_ids["attention_mask"][:, -sequence_length:]

            for j, past in enumerate(cache_entry.input_ids["past_key_values"]):
                # TODO: this could be done without the views by using indices
                past_keys = past[0]
                past_values = past[1]

                _, head_dim, padded_sequence_length = past_keys.shape

                past_keys = past_keys.view(
                    batch_size, -1, head_dim, padded_sequence_length
                )
                past_values = past_values.view(
                    batch_size, -1, padded_sequence_length, head_dim
                )
                num_heads = past_keys.shape[1]

                if j == len(input_ids["past_key_values"]):
                    padded_past_keys = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        ),
                        dtype=past_keys.dtype,
                        device=past_keys.device,
                    )
                    padded_past_values = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            max_sequence_length - 1,
                            head_dim,
                        ),
                        dtype=past_values.dtype,
                        device=past_values.device,
                    )
                    input_ids["past_key_values"].append(
                        [padded_past_keys, padded_past_values]
                    )

                input_ids["past_key_values"][j][0][
                    start_index:end_index, :, :, -(sequence_length - 1):
                ] = past_keys[:, :, :, -(sequence_length - 1):]

                input_ids["past_key_values"][j][1][
                    start_index:end_index, :, -(sequence_length - 1):, :
                ] = past_values[:, :, -(sequence_length - 1):, :]

                if (i + 1) == len(pb.batch_cached_ids):
                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
                        j
                    ][0].view(total_batch_size * num_heads, head_dim, -1)
                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
                        j
                    ][1].view(total_batch_size * num_heads, -1, head_dim)

            start_index += batch_size

        assert pb.request_ids == request_ids

        return cls(
            pb.id,
            request_ids,
            input_ids,
            all_input_ids,
            next_token_choosers,
            stopping_criterias,
        )


@dataclass
class FinishedGeneration:
    request_id: str
    output: str

    def to_pb(self) -> generate_pb2.FinishedGeneration:
        return generate_pb2.FinishedGeneration(id=self.request_id, output=self.output)


class BLOOM:
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
        )
        self.num_heads = self.model.base_model.num_heads

    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[FinishedGeneration], Optional[CacheEntry]]:
        with torch.no_grad():
            outputs = self.forward(**batch.input_ids)

        # List of indices to cache
        cache_indices = []
        cache_past_indices = []

        # New input_ids for next forward; keep in cache
        cache_next_input_ids = []
        cache_all_input_ids = []

        # Finished requests
        finished_generations: List[FinishedGeneration] = []

        # Zipped iterator
        iterator = zip(
            batch.request_ids,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request_id,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request id
                finished_generations.append(FinishedGeneration(request_id, output))
            # must be added to the cache
            else:
                cache_indices.append(i)
                cache_past_indices.extend(
                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
                )
                cache_next_input_ids.append(next_token)
                cache_all_input_ids.append(all_tokens)

        # No cache is needed, we finished all generations in the batch
        if not cache_indices:
            return finished_generations, None

        # If we finished at least one generation
        cache_input_ids = {"input_ids": torch.cat(cache_next_input_ids, dim=0)}
        if finished_generations:
            # Apply indices to attention mask, past key values and other items that need to be cached
            cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
                cache_indices
            ]
            cache_input_ids["past_key_values"] = [
                (keys[cache_past_indices], values[cache_past_indices])
                for keys, values in outputs["past_key_values"]
            ]
            cache_request_ids = [batch.request_ids[i] for i in cache_indices]
            cache_next_token_choosers = [
                batch.next_token_choosers[i] for i in cache_indices
            ]
            cache_stopping_criterias = [
                batch.stopping_criterias[i] for i in cache_indices
            ]
        else:
            cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
            cache_input_ids["past_key_values"] = outputs["past_key_values"]
            cache_request_ids = batch.request_ids
            cache_next_token_choosers = batch.next_token_choosers
            cache_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        cache_input_ids["attention_mask"] = torch.cat(
            [
                cache_input_ids["attention_mask"],
                torch.ones((cache_input_ids["attention_mask"].shape[0], 1)).to(
                    cache_input_ids["attention_mask"].device
                ),
            ],
            dim=1,
        )

        cache_entry = CacheEntry(
            batch.batch_id,
            cache_request_ids,
            cache_input_ids,
            cache_all_input_ids,
            cache_next_token_choosers,
            cache_stopping_criterias,
        )
        return finished_generations, cache_entry


class BLOOMSharded(BLOOM):
    def __init__(self, model_name: str, shard_directory: Path):
        super(BLOOM, self).__init__()
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

        # shard state_dict
        if self.master:
            # TODO @thomasw21 do some caching
            shard_state_dict_paths = shard_model(
                model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
            )
            shard_state_dict_paths = [
                str(path.absolute()) for path in shard_state_dict_paths
            ]
        else:
            shard_state_dict_paths = [None] * self.world_size

        torch.distributed.broadcast_object_list(
            shard_state_dict_paths, src=0, group=self.process_group
        )
        shard_state_dict_path = shard_state_dict_paths[self.rank]

        config = AutoConfig.from_pretrained(
            model_name, slow_but_exact=False, tp_parallel=True
        )
        config.pad_token_id = 3

        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
        # in PyTorch 1.12 and later.
        torch.backends.cuda.matmul.allow_tf32 = True

        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True

        with set_default_dtype(dtype):
            with no_init_weights():
                # we can probably set the device to `meta` here?
                model = AutoModelForCausalLM.from_config(config).to(dtype)

        torch.distributed.barrier(group=self.process_group)
        # print_rank_0(f"Initialized model")
        state_dict = torch.load(shard_state_dict_path)
        # TODO @thomasw21: HACK in order to transpose all weight prior
        for key in state_dict.keys():
            do_transpose = False
            if not match_suffix(key, "weight"):
                continue

            for potential_suffix in [
                "self_attention.query_key_value.weight",
                "self_attention.dense.weight",
                "dense_h_to_4h.weight",
                "dense_4h_to_h.weight",
            ]:
                if match_suffix(key, potential_suffix):
                    do_transpose = True

            if do_transpose:
                state_dict[key] = state_dict[key].transpose(1, 0).contiguous()

        model.load_state_dict(state_dict)
        self.model = model.to(self.device).eval()
        self.num_heads = config.n_head // self.process_group.size()
        torch.distributed.barrier(group=self.process_group)

    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        logits_shard = outputs.logits[:, -1, :].contiguous()

        batch_size, vocab_shard_size = logits_shard.shape
        vocab_size = self.world_size * vocab_shard_size
        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)

        outputs.logits = logits
        return outputs

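generate_token advances every sequence in a batch by exactly one token and splits the batch into finished generations plus a CacheEntry for the sequences that still need tokens; as an illustration (not part of this commit), a single-request decoding loop over the protobuf types looks roughly like this, with an arbitrary prompt and sampling parameters:

from bloom_inference.model import BLOOM, Batch
from bloom_inference.pb import generate_pb2

model = BLOOM("bigscience/bloom-560m")

request = generate_pb2.Request(
    id=0,
    inputs="Hello, my name is",
    parameters=generate_pb2.LogitsWarperParameters(
        temperature=1.0, top_k=0, top_p=1.0, do_sample=False
    ),
    max_new_tokens=20,
)
batch = Batch.from_batch_pb(
    generate_pb2.Batch(id=0, requests=[request]), model.tokenizer, model.device
)

# One forward pass per new token; resume from the returned cache entry until
# the stopping criteria mark the request as finished.
finished, cache_entry = model.generate_token(batch)
while not finished:
    batch = Batch.from_cache_entry(cache_entry)
    finished, cache_entry = model.generate_token(batch)

print(finished[0].output)
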
server/bloom_inference/pb/__init__.py  0 → 100644

server/bloom_inference/pb/__init__.py-e  0 → 100644
server/bloom_inference/pb/__pycache__/__init__.cpython-39.pyc  0 → 100644
File added

server/bloom_inference/pb/__pycache__/generate_pb2.cpython-39.pyc  0 → 100644
File added

server/bloom_inference/pb/__pycache__/generate_pb2_grpc.cpython-39.pyc  0 → 100644
File added
server/bloom_inference/pb/generate_pb2.py  0 → 100644

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: generate.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0egenerate.proto\x12\x0bgenerate.v1\"(\n\x18ServiceDiscoveryResponse\x12\x0c\n\x04urls\x18\x01 \x03(\t\"^\n\x16LogitsWarperParameters\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_k\x18\x02 \x01(\r\x12\r\n\x05top_p\x18\x03 \x01(\x02\x12\x11\n\tdo_sample\x18\x04 \x01(\x08\"v\n\x07Request\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06inputs\x18\x02 \x01(\t\x12\x37\n\nparameters\x18\x03 \x01(\x0b\x32#.generate.v1.LogitsWarperParameters\x12\x16\n\x0emax_new_tokens\x18\x04 \x01(\r\";\n\x05\x42\x61tch\x12\n\n\x02id\x18\x01 \x01(\x04\x12&\n\x08requests\x18\x02 \x03(\x0b\x32\x14.generate.v1.Request\"\x7f\n\x0b\x42\x61tchCached\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x18\n\x10\x62\x61tch_cached_ids\x18\x03 \x03(\x04\x12\x18\n\x10total_batch_size\x18\x04 \x01(\r\x12\x1b\n\x13max_sequence_length\x18\x05 \x01(\r\"0\n\x12\x46inishedGeneration\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06output\x18\x02 \x01(\t\"F\n\nCacheEntry\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x17\n\x0fsequence_length\x18\x03 \x01(\r\"\x80\x01\n\x08Response\x12\x31\n\x08\x66inished\x18\x01 \x03(\x0b\x32\x1f.generate.v1.FinishedGeneration\x12\x31\n\x0b\x63\x61\x63he_entry\x18\x02 \x01(\x0b\x32\x17.generate.v1.CacheEntryH\x00\x88\x01\x01\x42\x0e\n\x0c_cache_entry\"\x07\n\x05\x45mpty2\x94\x02\n\x0eTextGeneration\x12O\n\x10ServiceDiscovery\x12\x12.generate.v1.Empty\x1a%.generate.v1.ServiceDiscoveryResponse\"\x00\x12\x34\n\nClearCache\x12\x12.generate.v1.Empty\x1a\x12.generate.v1.Empty\x12\x35\n\x08Generate\x12\x12.generate.v1.Batch\x1a\x15.generate.v1.Response\x12\x44\n\x11GenerateWithCache\x12\x18.generate.v1.BatchCached\x1a\x15.generate.v1.Responseb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'generate_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  _SERVICEDISCOVERYRESPONSE._serialized_start=31
  _SERVICEDISCOVERYRESPONSE._serialized_end=71
  _LOGITSWARPERPARAMETERS._serialized_start=73
  _LOGITSWARPERPARAMETERS._serialized_end=167
  _REQUEST._serialized_start=169
  _REQUEST._serialized_end=287
  _BATCH._serialized_start=289
  _BATCH._serialized_end=348
  _BATCHCACHED._serialized_start=350
  _BATCHCACHED._serialized_end=477
  _FINISHEDGENERATION._serialized_start=479
  _FINISHEDGENERATION._serialized_end=527
  _CACHEENTRY._serialized_start=529
  _CACHEENTRY._serialized_end=599
  _RESPONSE._serialized_start=602
  _RESPONSE._serialized_end=730
  _EMPTY._serialized_start=732
  _EMPTY._serialized_end=739
  _TEXTGENERATION._serialized_start=742
  _TEXTGENERATION._serialized_end=1018
# @@protoc_insertion_point(module_scope)

server/bloom_inference/pb/generate_pb2.py-e  0 → 100644
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: generate.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0egenerate.proto\x12\x0bgenerate.v1\"(\n\x18ServiceDiscoveryResponse\x12\x0c\n\x04urls\x18\x01 \x03(\t\"^\n\x16LogitsWarperParameters\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_k\x18\x02 \x01(\r\x12\r\n\x05top_p\x18\x03 \x01(\x02\x12\x11\n\tdo_sample\x18\x04 \x01(\x08\"v\n\x07Request\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06inputs\x18\x02 \x01(\t\x12\x37\n\nparameters\x18\x03 \x01(\x0b\x32#.generate.v1.LogitsWarperParameters\x12\x16\n\x0emax_new_tokens\x18\x04 \x01(\r\";\n\x05\x42\x61tch\x12\n\n\x02id\x18\x01 \x01(\x04\x12&\n\x08requests\x18\x02 \x03(\x0b\x32\x14.generate.v1.Request\"\x7f\n\x0b\x42\x61tchCached\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x18\n\x10\x62\x61tch_cached_ids\x18\x03 \x03(\x04\x12\x18\n\x10total_batch_size\x18\x04 \x01(\r\x12\x1b\n\x13max_sequence_length\x18\x05 \x01(\r\"0\n\x12\x46inishedGeneration\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06output\x18\x02 \x01(\t\"F\n\nCacheEntry\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x17\n\x0fsequence_length\x18\x03 \x01(\r\"\x80\x01\n\x08Response\x12\x31\n\x08\x66inished\x18\x01 \x03(\x0b\x32\x1f.generate.v1.FinishedGeneration\x12\x31\n\x0b\x63\x61\x63he_entry\x18\x02 \x01(\x0b\x32\x17.generate.v1.CacheEntryH\x00\x88\x01\x01\x42\x0e\n\x0c_cache_entry\"\x07\n\x05\x45mpty2\x94\x02\n\x0eTextGeneration\x12O\n\x10ServiceDiscovery\x12\x12.generate.v1.Empty\x1a%.generate.v1.ServiceDiscoveryResponse\"\x00\x12\x34\n\nClearCache\x12\x12.generate.v1.Empty\x1a\x12.generate.v1.Empty\x12\x35\n\x08Generate\x12\x12.generate.v1.Batch\x1a\x15.generate.v1.Response\x12\x44\n\x11GenerateWithCache\x12\x18.generate.v1.BatchCached\x1a\x15.generate.v1.Responseb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'generate_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_SERVICEDISCOVERYRESPONSE._serialized_start=31
_SERVICEDISCOVERYRESPONSE._serialized_end=71
_LOGITSWARPERPARAMETERS._serialized_start=73
_LOGITSWARPERPARAMETERS._serialized_end=167
_REQUEST._serialized_start=169
_REQUEST._serialized_end=287
_BATCH._serialized_start=289
_BATCH._serialized_end=348
_BATCHCACHED._serialized_start=350
_BATCHCACHED._serialized_end=477
_FINISHEDGENERATION._serialized_start=479
_FINISHEDGENERATION._serialized_end=527
_CACHEENTRY._serialized_start=529
_CACHEENTRY._serialized_end=599
_RESPONSE._serialized_start=602
_RESPONSE._serialized_end=730
_EMPTY._serialized_start=732
_EMPTY._serialized_end=739
_TEXTGENERATION._serialized_start=742
_TEXTGENERATION._serialized_end=1018
# @@protoc_insertion_point(module_scope)

server/bloom_inference/pb/generate_pb2_grpc.py  0 → 100644

# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

from . import generate_pb2 as generate__pb2


class TextGenerationStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.ServiceDiscovery = channel.unary_unary(
                '/generate.v1.TextGeneration/ServiceDiscovery',
                request_serializer=generate__pb2.Empty.SerializeToString,
                response_deserializer=generate__pb2.ServiceDiscoveryResponse.FromString,
                )
        self.ClearCache = channel.unary_unary(
                '/generate.v1.TextGeneration/ClearCache',
                request_serializer=generate__pb2.Empty.SerializeToString,
                response_deserializer=generate__pb2.Empty.FromString,
                )
        self.Generate = channel.unary_unary(
                '/generate.v1.TextGeneration/Generate',
                request_serializer=generate__pb2.Batch.SerializeToString,
                response_deserializer=generate__pb2.Response.FromString,
                )
        self.GenerateWithCache = channel.unary_unary(
                '/generate.v1.TextGeneration/GenerateWithCache',
                request_serializer=generate__pb2.BatchCached.SerializeToString,
                response_deserializer=generate__pb2.Response.FromString,
                )


class TextGenerationServicer(object):
    """Missing associated documentation comment in .proto file."""

    def ServiceDiscovery(self, request, context):
        """/ Service discovery
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ClearCache(self, request, context):
        """/ Empties batch cache
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Generate(self, request, context):
        """/ Generate tokens for a batch without cache
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GenerateWithCache(self, request, context):
        """/ Generate tokens for a batch with cache
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_TextGenerationServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'ServiceDiscovery': grpc.unary_unary_rpc_method_handler(
                    servicer.ServiceDiscovery,
                    request_deserializer=generate__pb2.Empty.FromString,
                    response_serializer=generate__pb2.ServiceDiscoveryResponse.SerializeToString,
            ),
            'ClearCache': grpc.unary_unary_rpc_method_handler(
                    servicer.ClearCache,
                    request_deserializer=generate__pb2.Empty.FromString,
                    response_serializer=generate__pb2.Empty.SerializeToString,
            ),
            'Generate': grpc.unary_unary_rpc_method_handler(
                    servicer.Generate,
                    request_deserializer=generate__pb2.Batch.FromString,
                    response_serializer=generate__pb2.Response.SerializeToString,
            ),
            'GenerateWithCache': grpc.unary_unary_rpc_method_handler(
                    servicer.GenerateWithCache,
                    request_deserializer=generate__pb2.BatchCached.FromString,
                    response_serializer=generate__pb2.Response.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'generate.v1.TextGeneration', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class TextGeneration(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def ServiceDiscovery(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ServiceDiscovery',
            generate__pb2.Empty.SerializeToString,
            generate__pb2.ServiceDiscoveryResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def ClearCache(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ClearCache',
            generate__pb2.Empty.SerializeToString,
            generate__pb2.Empty.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Generate(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/Generate',
            generate__pb2.Batch.SerializeToString,
            generate__pb2.Response.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def GenerateWithCache(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/GenerateWithCache',
            generate__pb2.BatchCached.SerializeToString,
            generate__pb2.Response.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

server/bloom_inference/pb/generate_pb2_grpc.py-e  0 → 100644
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import generate_pb2 as generate__pb2
class TextGenerationStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ServiceDiscovery = channel.unary_unary(
'/generate.v1.TextGeneration/ServiceDiscovery',
request_serializer=generate__pb2.Empty.SerializeToString,
response_deserializer=generate__pb2.ServiceDiscoveryResponse.FromString,
)
self.ClearCache = channel.unary_unary(
'/generate.v1.TextGeneration/ClearCache',
request_serializer=generate__pb2.Empty.SerializeToString,
response_deserializer=generate__pb2.Empty.FromString,
)
self.Generate = channel.unary_unary(
'/generate.v1.TextGeneration/Generate',
request_serializer=generate__pb2.Batch.SerializeToString,
response_deserializer=generate__pb2.Response.FromString,
)
self.GenerateWithCache = channel.unary_unary(
'/generate.v1.TextGeneration/GenerateWithCache',
request_serializer=generate__pb2.BatchCached.SerializeToString,
response_deserializer=generate__pb2.Response.FromString,
)
class TextGenerationServicer(object):
"""Missing associated documentation comment in .proto file."""
def ServiceDiscovery(self, request, context):
"""/ Service discovery
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ClearCache(self, request, context):
"""/ Empties batch cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Generate(self, request, context):
"""/ Generate tokens for a batch without cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GenerateWithCache(self, request, context):
"""/ Generate tokens for a batch with cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TextGenerationServicer_to_server(servicer, server):
rpc_method_handlers = {
'ServiceDiscovery': grpc.unary_unary_rpc_method_handler(
servicer.ServiceDiscovery,
request_deserializer=generate__pb2.Empty.FromString,
response_serializer=generate__pb2.ServiceDiscoveryResponse.SerializeToString,
),
'ClearCache': grpc.unary_unary_rpc_method_handler(
servicer.ClearCache,
request_deserializer=generate__pb2.Empty.FromString,
response_serializer=generate__pb2.Empty.SerializeToString,
),
'Generate': grpc.unary_unary_rpc_method_handler(
servicer.Generate,
request_deserializer=generate__pb2.Batch.FromString,
response_serializer=generate__pb2.Response.SerializeToString,
),
'GenerateWithCache': grpc.unary_unary_rpc_method_handler(
servicer.GenerateWithCache,
request_deserializer=generate__pb2.BatchCached.FromString,
response_serializer=generate__pb2.Response.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'generate.v1.TextGeneration', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class TextGeneration(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ServiceDiscovery(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ServiceDiscovery',
generate__pb2.Empty.SerializeToString,
generate__pb2.ServiceDiscoveryResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def ClearCache(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ClearCache',
generate__pb2.Empty.SerializeToString,
generate__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Generate(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/Generate',
generate__pb2.Batch.SerializeToString,
generate__pb2.Response.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GenerateWithCache(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/GenerateWithCache',
generate__pb2.BatchCached.SerializeToString,
generate__pb2.Response.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

server/bloom_inference/prepare_weights.py  0 → 100644

import torch

from pathlib import Path
from tqdm import tqdm

MODEL_NAME = "bigscience/bloom"


def match_suffix(text, suffix):
    return text[-len(suffix):] == suffix


def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
    save_paths = [
        save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]

    if all(save_path.exists() for save_path in save_paths):
        print("Weights are already prepared")
        return

    shards_state_dicts = [{} for _ in range(tp_world_size)]

    for weight_path in tqdm(hub_path.glob("*.bin")):
        state_dict = torch.load(weight_path, map_location="cpu")

        keys = list(state_dict.keys())
        for state_name in keys:
            state = state_dict[state_name]
            if any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.query_key_value.weight",
                    "self_attention.query_key_value.bias",
                    "mlp.dense_h_to_4h.weight",
                    "mlp.dense_h_to_4h.bias",
                    "word_embeddings.weight",
                    "lm_head.weight",
                ]
            ):
                output_size = state.shape[0]
                assert output_size % tp_world_size == 0
                block_size = output_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=0)
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    assert shard.shape[0] == block_size
                    if match_suffix(state_name, "lm_head.weight"):
                        shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
                    else:
                        shards_state_dicts[tp_rank][
                            "transformer." + state_name
                        ] = shard.detach().clone()
            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.weight",
                    "mlp.dense_4h_to_h.weight",
                    "lm_head.weight",
                ]
            ):
                input_size = state.shape[1]
                assert input_size % tp_world_size == 0
                block_size = input_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=1)
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    assert shard.shape[1] == block_size
                    if match_suffix(state_name, "lm_head.weight"):
                        shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
                    else:
                        shards_state_dicts[tp_rank][
                            "transformer." + state_name
                        ] = shard.detach().clone()
            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.bias",
                    "mlp.dense_4h_to_h.bias",
                ]
            ):
                shards_state_dicts[0][
                    "transformer." + state_name
                ] = state.detach().clone()
                for tp_rank in range(1, tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = torch.zeros_like(state)
            else:
                # We duplicate parameters across tp ranks
                for tp_rank in range(tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = state.detach().clone()

            del state_dict[state_name]  # delete key from state_dict
            del state  # delete tensor

    # we save state_dict
    for tp_rank, (save_path, shard_state_dict) in enumerate(
        zip(save_paths, shards_state_dicts)
    ):
        save_paths.append(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        if save_path.exists():
            print(f"Skipping {save_path} as it already exists")
        else:
            torch.save(shard_state_dict, save_path)
    return save_paths


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()

    parser.add_argument("--hub-path", required=True, type=str)
    parser.add_argument("--save-path", required=True, type=str)
    parser.add_argument("--world-size", required=True, type=int)
    args = parser.parse_args()

    prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)

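The same preparation can be driven from Python instead of the argparse entry point; a minimal sketch with placeholder paths (hub_path must contain the downloaded *.bin checkpoint files):

from pathlib import Path

from bloom_inference.prepare_weights import prepare_weights

# Writes one `{MODEL_NAME}_tp-rank-{rank}-of-{world_size}.pty` file per rank
# under save_path, skipping shards that already exist.
prepare_weights(
    hub_path=Path("/data/bloom-hub"),  # placeholder: local snapshot of bigscience/bloom
    save_path=Path("/data/shards"),
    tp_world_size=8,
)
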
server/bloom_inference/server.py  0 → 100644

import asyncio

from grpc import aio
from grpc_reflection.v1alpha import reflection

from pathlib import Path
from typing import Optional, List

from bloom_inference.cache import Cache
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
from bloom_inference.pb import generate_pb2_grpc, generate_pb2


class TextGeneration(generate_pb2_grpc.TextGenerationServicer):
    def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
        self.cache = cache
        self.model = model
        self.server_urls = server_urls

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        self.cache.clear()
        return generate_pb2.Empty()

    async def Generate(self, request, context):
        batch = Batch.from_batch_pb(request, self.model.tokenizer, self.model.device)
        finished_generations, cache_entry = self.model.generate_token(batch)
        self.cache.set(cache_entry)

        return generate_pb2.Response(
            finished=[
                finished_generation.to_pb()
                for finished_generation in finished_generations
            ],
            cache_entry=cache_entry.to_pb() if cache_entry else None,
        )

    async def GenerateWithCache(self, request, context):
        batch = Batch.from_batch_cached_pb(request, self.cache)
        finished_generations, cache_entry = self.model.generate_token(batch)
        self.cache.set(cache_entry)

        return generate_pb2.Response(
            finished=[
                finished_generation.to_pb()
                for finished_generation in finished_generations
            ],
            cache_entry=cache_entry.to_pb() if cache_entry else None,
        )


def serve(model_name, sharded, shard_directory):
    async def serve_inner(
        model_name: str,
        sharded: bool = False,
        shard_directory: Optional[Path] = None,
    ):
        unix_socket_template = "unix:///tmp/bloom-inference-{}"
        if sharded:
            if shard_directory is None:
                raise ValueError("shard_directory must be set when sharded is True")
            model = BLOOMSharded(model_name, shard_directory)
            server_urls = [
                unix_socket_template.format(rank) for rank in range(model.world_size)
            ]
            local_url = unix_socket_template.format(model.rank)
        else:
            model = BLOOM(model_name)
            local_url = unix_socket_template.format(0)
            server_urls = [local_url]

        server = aio.server()
        generate_pb2_grpc.add_TextGenerationServicer_to_server(
            TextGeneration(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGeneration"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)
        await server.start()
        print("Server started at {}".format(local_url))
        await server.wait_for_termination()

    asyncio.run(serve_inner(model_name, sharded, shard_directory))


if __name__ == "__main__":
    serve("bigscience/bloom-560m", True, Path("/tmp/models"))

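Once serve is running, the service can be exercised over its Unix socket with the generated stubs; an illustrative client (not part of this commit), assuming a non-sharded server on the socket path used in serve_inner, with an arbitrary prompt and parameters:

import grpc

from bloom_inference.pb import generate_pb2, generate_pb2_grpc

channel = grpc.insecure_channel("unix:///tmp/bloom-inference-0")
stub = generate_pb2_grpc.TextGenerationStub(channel)

request = generate_pb2.Request(
    id=0,
    inputs="Hello, my name is",
    parameters=generate_pb2.LogitsWarperParameters(
        temperature=1.0, top_k=0, top_p=1.0, do_sample=False
    ),
    max_new_tokens=20,
)

# Each call produces one token per sequence; feed the returned cache entry back
# through GenerateWithCache until every request in the batch has finished.
response = stub.Generate(generate_pb2.Batch(id=0, requests=[request]))
while not response.finished:
    entry = response.cache_entry
    response = stub.GenerateWithCache(
        generate_pb2.BatchCached(
            id=entry.id,
            request_ids=entry.request_ids,
            batch_cached_ids=[entry.id],
            total_batch_size=len(entry.request_ids),
            max_sequence_length=entry.sequence_length,
        )
    )

print(response.finished[0].output)
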
server/bloom_inference/shard_model.py  0 → 100644

from pathlib import Path

import torch
from torch import nn
from transformers import AutoModelForCausalLM


def match_suffix(text, suffix):
    return text[-len(suffix):] == suffix


def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
    """BLOOM specific sharding mechanism"""
    save_paths = [
        path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]
    if all(save_path.exists() for save_path in save_paths):
        print("Loading already cached values")
        return save_paths

    model: nn.Module = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, local_files_only=True
    )

    shards_state_dicts = [{} for _ in range(tp_world_size)]

    state_dict = model.state_dict()
    keys = list(state_dict.keys())
    for state_name in keys:
        print(state_name)
        state = state_dict[state_name]
        if any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.query_key_value.weight",
                "self_attention.query_key_value.bias",
                "mlp.dense_h_to_4h.weight",
                "mlp.dense_h_to_4h.bias",
                "transformer.word_embeddings.weight",
                "lm_head.weight",
            ]
        ):
            output_size = state.shape[0]
            assert output_size % tp_world_size == 0
            block_size = output_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=0)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[0] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.weight",
                "mlp.dense_4h_to_h.weight",
                "lm_head.weight",
            ]
        ):
            input_size = state.shape[1]
            assert input_size % tp_world_size == 0
            block_size = input_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=1)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[1] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.bias",
                "mlp.dense_4h_to_h.bias",
            ]
        ):
            shards_state_dicts[0][state_name] = state.detach().clone()
            for tp_rank in range(1, tp_world_size):
                shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
        else:
            # We duplicate parameters across tp ranks
            for tp_rank in range(tp_world_size):
                shards_state_dicts[tp_rank][state_name] = state.detach().clone()

        del state_dict[state_name]  # delete key from state_dict
        del state  # delete tensor

    # we save state_dict
    for tp_rank, (save_path, shard_state_dict) in enumerate(
        zip(save_paths, shards_state_dicts)
    ):
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(shard_state_dict, save_path)
        save_paths.append(save_path)

    return save_paths


if __name__ == "__main__":
    model_name = "bigscience/bloom"
    save_path = Path("/data/shards")
    tp_world_size = 8
    dtype = torch.bfloat16

    shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)

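In Megatron-style terms, the first branch shards column-parallel weights (query_key_value, dense_h_to_4h, the embeddings and lm_head) along dim 0, the second shards row-parallel weights (self_attention.dense, dense_4h_to_h) along dim 1, and the row-parallel biases are kept on rank 0 only; a tiny self-contained illustration of the split arithmetic, with made-up shapes:

import torch

tp_world_size = 4

# Column-parallel (dim 0): each rank owns a slice of the output features.
w_col = torch.randn(1024, 256)
col_shards = torch.split(w_col, 1024 // tp_world_size, dim=0)
assert len(col_shards) == tp_world_size and col_shards[0].shape == (256, 256)

# Row-parallel (dim 1): each rank owns a slice of the input features.
w_row = torch.randn(256, 1024)
row_shards = torch.split(w_row, 1024 // tp_world_size, dim=1)
assert len(row_shards) == tp_world_size and row_shards[0].shape == (256, 256)
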