OpenDAS / text-generation-inference · Commit 15511edc (unverified)

feat(server): Support SantaCoder (#26)

Authored by OlivierDehaene on Jan 20, 2023; committed via GitHub on Jan 20, 2023.
Parent: f7ac3949
Showing 17 changed files with 320 additions and 78 deletions (+320 -78).
  README.md                                     +1    -0
  launcher/src/main.rs                          +3    -7
  router/src/batcher.rs                         +4    -3
  router/src/validation.rs                      +8    -8
  server/tests/conftest.py                      +0   -23
  server/tests/models/test_bloom.py            +11    -5
  server/tests/models/test_causal_lm.py        +13    -5
  server/tests/models/test_santacoder.py       +93    -0
  server/tests/models/test_seq2seq_lm.py       +16    -5
  server/text_generation/models/__init__.py    +21    -1
  server/text_generation/models/bloom.py       +10    -9
  server/text_generation/models/causal_lm.py   +14    -5
  server/text_generation/models/galactica.py   +23    -3
  server/text_generation/models/santacoder.py  +87    -0
  server/text_generation/models/seq2seq_lm.py   +9    -2
  server/text_generation/models/types.py        +4    -1
  server/text_generation/utils.py               +3    -1
README.md

@@ -25,6 +25,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
+- [SantaCoder](https://huggingface.co/bigcode/santacoder)

 Other models are supported on a best effort basis using:
launcher/src/main.rs

 use clap::Parser;
+use serde_json::Value;
 use std::env;
 use std::io::{BufRead, BufReader, Read};
 use std::path::Path;

@@ -11,7 +12,6 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
-use serde_json::Value;
 use subprocess::{Popen, PopenConfig, PopenError, Redirection};

@@ -299,16 +299,12 @@ fn shard_manager(
     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
     // Useful when running inside a docker container
     if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
-        env.push((
-            "HUGGINGFACE_HUB_CACHE".into(),
-            huggingface_hub_cache.into(),
-        ));
+        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };

     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
     if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
-        env.push((
-            "CUDA_VISIBLE_DEVICES".into(),
-            cuda_visible_devices.into(),
-        ));
+        env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
     };

     // Start process
router/src/batcher.rs

Formatting-only re-wrap of the response handling (same expression, different line breaks); the resulting code:

@@ -74,9 +74,10 @@ impl Batcher {
         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
         response_rx
             .await
             .unwrap()
             .map_err(|err| InferError::GenerationError(err.to_string()))
     }
 }
router/src/validation.rs

Mostly formatting changes; hunks below show the resulting code, with the substantive line changes marked.

@@ -94,7 +94,9 @@ fn validation_worker(
 ) {
     // Loop over requests
     while let Some((request, response_tx)) = receiver.blocking_recv() {
         response_tx
             .send(validate(request, &tokenizer, max_input_length))
             .unwrap_or(())
     }
 }

@@ -117,8 +119,9 @@ fn validate(
     }
     if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
         return Err(ValidationError::StopSequence(
             MAX_STOP_SEQUENCES,
             request.parameters.stop.len(),
-        ))
+        ));
     }

     // Get the number of tokens in the input

@@ -127,14 +130,11 @@ fn validate(
             let input_length = inputs.len();

             if input_length > max_input_length {
-                Err(ValidationError::InputLength(
-                    input_length,
-                    max_input_length,
-                ))
+                Err(ValidationError::InputLength(input_length, max_input_length))
             } else {
                 Ok((input_length, request))
             }
         }
         Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
     }
 }
server/tests/conftest.py

 import pytest

-from transformers import AutoTokenizer
 from text_generation.pb import generate_pb2

@@ -18,24 +16,3 @@ def default_pb_parameters():
 @pytest.fixture
 def default_pb_stop_parameters():
     return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
-
-
-@pytest.fixture(scope="session")
-def bloom_560m_tokenizer():
-    return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
-
-
-@pytest.fixture(scope="session")
-def gpt2_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
-    tokenizer.pad_token_id = 50256
-    return tokenizer
-
-
-@pytest.fixture(scope="session")
-def mt0_small_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small", padding_side="left")
-    tokenizer.bos_token_id = 0
-    return tokenizer
server/tests/models/test_bloom.py

@@ -2,12 +2,23 @@ import pytest
 import torch

 from copy import copy
+from transformers import AutoTokenizer

 from text_generation.pb import generate_pb2
 from text_generation.models.causal_lm import CausalLMBatch
 from text_generation.models.bloom import BloomCausalLMBatch, BLOOM


+@pytest.fixture(scope="session")
+def default_bloom():
+    return BLOOM("bigscience/bloom-560m")
+
+
+@pytest.fixture(scope="session")
+def bloom_560m_tokenizer():
+    return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
+
+
 @pytest.fixture
 def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(

@@ -44,11 +55,6 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     )


-@pytest.fixture(scope="session")
-def default_bloom():
-    return BLOOM("bigscience/bloom-560m")
-
-
 def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     batch = default_bloom_batch
server/tests/models/test_causal_lm.py

@@ -2,11 +2,24 @@ import pytest
 import torch

 from copy import copy
+from transformers import AutoTokenizer

 from text_generation.pb import generate_pb2
 from text_generation.models.causal_lm import CausalLM, CausalLMBatch


+@pytest.fixture(scope="session")
+def default_causal_lm():
+    return CausalLM("gpt2")
+
+
+@pytest.fixture(scope="session")
+def gpt2_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+    tokenizer.pad_token_id = 50256
+    return tokenizer
+
+
 @pytest.fixture
 def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(

@@ -39,11 +52,6 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))


-@pytest.fixture(scope="session")
-def default_causal_lm():
-    return CausalLM("gpt2")
-
-
 def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     batch = default_causal_lm_batch
server/tests/models/test_santacoder.py (new file, 0 → 100644)

import pytest

from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.santacoder import SantaCoder


@pytest.fixture(scope="session")
def default_santacoder():
    return SantaCoder("bigcode/santacoder")


@pytest.fixture
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="def",
        input_length=1,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )


@pytest.fixture
def default_pb_batch(default_pb_request):
    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)


@pytest.fixture
def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
        input_length=5,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )


@pytest.fixture
def default_fim_pb_batch(default_fim_pb_request):
    return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)


def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
    batch = CausalLMBatch.from_pb(
        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
    )
    next_batch = batch

    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert generated_texts[0].output_text == "def test_get_all_users_with_"
    assert generated_texts[0].request == batch.requests[0]
    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
    assert (
        generated_texts[0].generated_tokens
        == batch.stopping_criterias[0].max_new_tokens
    )


def test_fim_santacoder_generate_token_completion(
    default_santacoder, default_fim_pb_batch
):
    batch = CausalLMBatch.from_pb(
        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
    )
    next_batch = batch

    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert (
        generated_texts[0].output_text
        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
    )
    assert generated_texts[0].request == batch.requests[0]
    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
    assert (
        generated_texts[0].generated_tokens
        == batch.stopping_criterias[0].max_new_tokens
    )
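
Note on the FIM fixture above: the prompt string is simply the prefix and suffix joined by the special tokens that the SantaCoder class registers, with <fim-middle> at the end so the model generates the missing middle. A minimal sketch of how such a prompt could be assembled (the helper name is ours, not part of the repository):

    FIM_PREFIX = "<fim-prefix>"
    FIM_MIDDLE = "<fim-middle>"
    FIM_SUFFIX = "<fim-suffix>"


    def build_fim_prompt(prefix: str, suffix: str) -> str:
        # Hypothetical helper: the model is expected to generate the "middle"
        # after the <fim-middle> marker.
        return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


    # Matches the default_fim_pb_request fixture above.
    assert build_fim_prompt("def", "world") == "<fim-prefix>def<fim-suffix>world<fim-middle>"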
server/tests/models/test_seq2seq_lm.py

@@ -3,10 +3,26 @@ import torch
 from copy import copy
+from transformers import AutoTokenizer

 from text_generation.pb import generate_pb2
 from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch


+@pytest.fixture(scope="session")
+def mt0_small_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(
+        "bigscience/mt0-small", padding_side="left"
+    )
+    tokenizer.bos_token_id = 0
+    return tokenizer
+
+
+@pytest.fixture(scope="session")
+def default_seq2seq_lm():
+    return Seq2SeqLM("bigscience/mt0-small")
+
+
 @pytest.fixture
 def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(

@@ -41,11 +57,6 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
     return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))


-@pytest.fixture(scope="session")
-def default_seq2seq_lm():
-    return Seq2SeqLM("bigscience/mt0-small")
-
-
 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     batch = default_seq2seq_lm_batch
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
server/text_generation/models/__init__.py

+import torch
+
 from text_generation.models.model import Model
 from text_generation.models.causal_lm import CausalLM
 from text_generation.models.bloom import BLOOM, BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
+from text_generation.models.santacoder import SantaCoder

-__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM", "get_model"]
+__all__ = [
+    "Model",
+    "BLOOM",
+    "BLOOMSharded",
+    "CausalLM",
+    "Seq2SeqLM",
+    "SantaCoder",
+    "get_model",
+]
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True


 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:

@@ -18,6 +36,8 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
             return GalacticaSharded(model_name, quantize=quantize)
         else:
             return Galactica(model_name, quantize=quantize)
+    elif "santacoder" in model_name:
+        return SantaCoder(model_name, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
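
With the new branch in place, any model name containing "santacoder" is routed to the SantaCoder class. A minimal usage sketch (instantiating the class loads the checkpoint, so this assumes the weights are locally cached or downloadable):

    from text_generation.models import get_model

    # "santacoder" in the model name selects the new SantaCoder class;
    # sharded serving is not implemented for it in this commit.
    model = get_model("bigcode/santacoder", sharded=False, quantize=False)
    print(type(model).__name__)  # SantaCoder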
server/text_generation/models/bloom.py

@@ -5,7 +5,12 @@ from typing import List, Optional, Type
 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoConfig,
+    PreTrainedTokenizerBase,
+)
 from transformers.models.bloom.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,

@@ -34,7 +39,10 @@ torch.manual_seed(0)
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "CausalLMBatch":
         batch = super(BloomCausalLMBatch, cls).from_pb(
             pb=pb, tokenizer=tokenizer, device=device

@@ -70,13 +78,6 @@ class BLOOMSharded(BLOOM):
         )
         config.pad_token_id = 3

-        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-        # in PyTorch 1.12 and later.
-        torch.backends.cuda.matmul.allow_tf32 = True
-
-        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-        torch.backends.cudnn.allow_tf32 = True
-
         # Only download weights for small models
         if self.master and model_name == "bigscience/bloom-560m":
             download_weights(model_name, extension=".safetensors")
server/text_generation/models/causal_lm.py

@@ -47,7 +47,10 @@ class CausalLMBatch(Batch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []

@@ -71,6 +74,7 @@ class CausalLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             pad_to_multiple_of=pad_to_multiple_of,
+            return_token_type_ids=False,
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

@@ -253,6 +257,11 @@ class CausalLM(Model):
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch

+    def decode(self, generated_ids: List[int]) -> str:
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
+        )
+
     def forward(
         self, input_ids, attention_mask, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:

@@ -338,11 +347,11 @@ class CausalLM(Model):
                 ),
             )

             if stop:
-                # Decode all tokens
-                output_text = self.tokenizer.decode(
-                    all_input_ids.squeeze(-1), skip_special_tokens=True,
-                    cleanup_tokenization_spaces=False
+                # Decode generated tokens
+                generated_text = self.decode(
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
+                output_text = request.inputs + generated_text

                 # Slice with input_length to remove padding
                 token_ids = all_input_ids[-new_input_length:]
                 tokens = self.tokenizer.batch_decode(token_ids)
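
The last hunk changes how output_text is assembled: instead of re-decoding the whole sequence with skip_special_tokens=True (which would strip markers such as SantaCoder's FIM tokens from the echoed prompt), only the newly generated tokens are decoded and the raw request input is prepended. A rough illustration of the slicing, with made-up token ids and counts:

    import torch

    # Shape [seq_length, 1], as produced by CausalLMBatch: prompt followed by generated tokens.
    all_input_ids = torch.tensor([[101], [202], [303], [404], [505]])
    current_tokens = 2  # stand-in for stopping_criteria.current_tokens

    # Keep only the generated suffix and drop the trailing dimension, as in the hunk above.
    generated_ids = all_input_ids[-current_tokens:, 0]
    print(generated_ids.tolist())  # [404, 505]

    # output_text = request.inputs + self.decode(generated_ids)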
server/text_generation/models/galactica.py

@@ -6,7 +6,12 @@ from typing import List, Optional, Type
 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoConfig,
+    PreTrainedTokenizerBase,
+)
 from transformers.models.opt.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,

@@ -82,7 +87,10 @@ def escape_custom_split_sequence(text):
 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []

@@ -99,8 +107,14 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

         # Tokenize batch
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_token_type_ids=False,
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

@@ -124,6 +138,12 @@ class Galactica(CausalLM):
     def batch_type(self) -> Type[CausalLMBatch]:
         return GalacticaCausalLMBatch

+    def decode(self, generated_ids: List[int]) -> str:
+        # Do not skip special tokens as they are used for custom parsing rules of the generated text
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
+        )
+

 class GalacticaSharded(Galactica):
     def __init__(self, model_name: str, quantize: bool = False):
server/text_generation/models/santacoder.py (new file, 0 → 100644)

import torch
import torch.distributed

from typing import Optional, List, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM

from text_generation.models import CausalLM

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"


class SantaCoder(CausalLM):
    def __init__(self, model_name: str, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    EOD,
                    FIM_PREFIX,
                    FIM_MIDDLE,
                    FIM_SUFFIX,
                    FIM_PAD,
                ],
                "pad_token": EOD,
            }
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
            trust_remote_code=True,  # required
        ).eval()

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # FIXME: current forward with past is bugged for bigcode/santacoder because past_key_values does not have
        #  the correct shape ([batch_size, D, seq_length] instead of [batch_size, seq_length, D]);
        #  this leads to position_ids being wrong

        input_length = input_ids.shape[-1]
        past_key_values_length = (
            0 if past_key_values is None else past_key_values[0][0].shape[-1]
        )
        position_ids = torch.arange(
            past_key_values_length,
            input_length + past_key_values_length,
            dtype=torch.long,
            device=input_ids.device,
        ).view(1, input_length)

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values
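
Because of the past_key_values shape issue noted in the FIXME, SantaCoder.forward builds position_ids by hand rather than letting the model infer them from the cache. A standalone sketch of the offset logic, with made-up lengths (the helper name is ours):

    import torch


    def make_position_ids(input_length: int, past_length: int) -> torch.Tensor:
        # Positions continue from wherever the cached prefix ended.
        return torch.arange(
            past_length, input_length + past_length, dtype=torch.long
        ).view(1, input_length)


    print(make_position_ids(5, 0))  # prefill pass: tensor([[0, 1, 2, 3, 4]])
    print(make_position_ids(1, 5))  # incremental decode step: tensor([[5]])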
server/text_generation/models/seq2seq_lm.py

@@ -51,7 +51,10 @@ class Seq2SeqLMBatch(Batch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []

@@ -83,6 +86,7 @@ class Seq2SeqLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             pad_to_multiple_of=pad_to_multiple_of,
+            return_token_type_ids=False,
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

@@ -318,6 +322,9 @@ class Seq2SeqLM(Model):
     def batch_type(self) -> Type[Seq2SeqLMBatch]:
         return Seq2SeqLMBatch

+    def decode(self, decoder_ids: List[int]) -> str:
+        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+
     def forward(
         self,
         input_ids,

@@ -438,7 +445,7 @@ class Seq2SeqLM(Model):
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
                 token_ids = decoder_input_ids[-new_decoder_input_length:]
-                output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+                output_text = self.decode(token_ids)
                 tokens = self.tokenizer.batch_decode(token_ids)
                 # Add NaN for the bos token
                 logprobs = [float("nan")] + decoder_logprobs[
server/text_generation/models/types.py

@@ -17,7 +17,10 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Batch":
         raise NotImplementedError
server/text_generation/utils.py

@@ -114,7 +114,9 @@ class StoppingCriteria:
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences