OpenDAS / text-generation-inference / Commits / c503a639

Unverified commit c503a639, authored Feb 07, 2023 by OlivierDehaene, committed by GitHub on Feb 07, 2023
feat(server): support t5 (#59)
parent 2fe5e1b3
Showing 3 changed files with 271 additions and 1 deletion (+271 -1):
README.md  +2 -1
server/text_generation/models/__init__.py  +11 -0
server/text_generation/models/t5.py  +258 -0
README.md

@@ -53,7 +53,8 @@ to power LLMs api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl): use `--revision pr/26`

 Other models are supported on a best effort basis using:
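Note (not part of this commit): once a T5-family model such as FLAN-T5-XXL is being served, it can be queried like any other supported model. A minimal Python sketch, assuming the `/generate` route and JSON request schema documented elsewhere in this README and a server listening on 127.0.0.1:8080:

# Hedged sketch: querying a running text-generation-inference server.
# The endpoint and payload shape are assumptions based on the README,
# not something introduced by this diff.
import requests

response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Translate to German: The house is wonderful.",
        "parameters": {"max_new_tokens": 20},
    },
)
print(response.json())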
server/text_generation/models/__init__.py

@@ -10,14 +10,20 @@ from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
 from text_generation.models.santacoder import SantaCoder
 from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation.models.t5 import T5Sharded

 __all__ = [
     "Model",
     "BLOOM",
     "BLOOMSharded",
     "CausalLM",
+    "Galactica",
+    "GalacticaSharded",
+    "GPTNeox",
+    "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
+    "T5Sharded",
     "get_model",
 ]

@@ -47,6 +53,11 @@ def get_model(
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
             return GPTNeox(model_id, revision, quantize=quantize)
+    elif config.model_type == "t5":
+        if sharded:
+            return T5Sharded(model_id, revision, quantize=quantize)
+        else:
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
     elif model_id.startswith("facebook/galactica"):
         if sharded:
             return GalacticaSharded(model_id, revision, quantize=quantize)
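Note (illustration only, not part of this diff): T5 checkpoints such as google/flan-t5-xxl report model_type == "t5", so they now hit the branch added above. A minimal sketch of the resulting dispatch; the keyword names below simply mirror the variables used inside get_model and are an assumption about its exact signature:

# Hedged sketch of the new routing: single shard -> generic Seq2SeqLM,
# several shards -> the new T5Sharded implementation.
from transformers import AutoConfig

from text_generation.models import get_model

config = AutoConfig.from_pretrained("google/flan-t5-xxl")
assert config.model_type == "t5"

# Keyword names are assumed to match the variables referenced in get_model.
model = get_model(
    model_id="google/flan-t5-xxl", revision=None, sharded=False, quantize=False
)
# -> Seq2SeqLM instance; with sharded=True (launched across GPUs) -> T5Sharded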
server/text_generation/models/t5.py (new file, mode 100644)

import torch
import torch.distributed

from typing import List, Optional, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoConfig,
)
from transformers.models.t5.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation.models import Seq2SeqLM
from text_generation.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False

class T5Sharded(Seq2SeqLM):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        # Only master download weights
        if self.master:
            download_weights(model_id, revision=revision, extension=".safetensors")

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        if not filenames:
            raise ValueError("No safetensors weights found")

        with init_empty_weights():
            model = AutoModelForSeq2SeqLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "lm_head.weight":
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif "relative_attention_bias.weight" in name:
                        size = slice_.get_shape()[1]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[:, start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)

                        else:
                            tensor = tensor.to(device)

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [encoder_last_hidden_state]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return (
            logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )
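Note (illustration only, not part of the commit): the tensor-parallel loading above always splits a weight dimension into equal contiguous blocks, one per rank. Column-parallel and embedding weights are sliced along dim 0, row-parallel and relative_attention_bias weights along dim 1, and forward re-assembles the sharded logits with all_gather followed by torch.cat along the vocabulary dimension. A small sketch of that slicing arithmetic:

# Minimal sketch of the per-rank slice bounds used in load_weights above.
from typing import Tuple


def shard_bounds(size: int, rank: int, world_size: int) -> Tuple[int, int]:
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return start, stop


# e.g. splitting a 4096-wide dimension across 2 ranks:
assert shard_bounds(4096, 0, 2) == (0, 2048)
assert shard_bounds(4096, 1, 2) == (2048, 4096)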