OpenDAS / text-generation-inference · Commits

Commit 3fef90d5 (unverified)
Authored Mar 07, 2023 by OlivierDehaene; committed via GitHub on Mar 07, 2023
Parent: 0e9ed1a8

feat(clients): Python client (#103)

Changes: 55 files in total; this page shows 20 changed files, with 67 additions and 58 deletions (+67 -58).
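
The headline change is a new `text_generation` Python client for the inference server; to free up that name, the server-side package is renamed from text_generation to text_generation_server, and most of the diffs below are mechanical import and path updates for the rename. As a rough sketch of how the new client is meant to be used (names follow the client added in #103; exact signatures are best checked against the client source):

    # Sketch: expected usage of the new Python client against a running
    # text-generation-inference server. Treat signatures as approximate.
    from text_generation import Client

    # Point the client at a locally running server (default router port).
    client = Client("http://127.0.0.1:8080")

    # Blocking generation: returns the full response once decoding finishes.
    response = client.generate("What is Deep Learning?", max_new_tokens=17)
    print(response.generated_text)

    # Streaming generation: tokens arrive incrementally over server-sent events.
    text = ""
    for stream_response in client.generate_stream(
        "What is Deep Learning?", max_new_tokens=17
    ):
        if not stream_response.token.special:
            text += stream_response.token.text
    print(text)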
Files changed on this page:

router/src/validation.rs                            +9  -9
server/.gitignore                                   +2  -2
server/Makefile                                     +5  -5
server/pyproject.toml                               +2  -2
server/tests/conftest.py                            +1  -1
server/tests/models/test_bloom.py                   +3  -3
server/tests/models/test_causal_lm.py               +2  -2
server/tests/models/test_santacoder.py              +3  -3
server/tests/models/test_seq2seq_lm.py              +2  -2
server/tests/utils/test_convert.py                  +6  -2
server/tests/utils/test_hub.py                      +1  -1
server/tests/utils/test_tokens.py                   +1  -1
server/text_generation_server/__init__.py           +0  -0
server/text_generation_server/cache.py              +1  -1
server/text_generation_server/cli.py                +4  -4
server/text_generation_server/interceptor.py        +0  -0
server/text_generation_server/models/__init__.py    +8  -8
server/text_generation_server/models/bloom.py       +4  -4
server/text_generation_server/models/causal_lm.py   +9  -4
server/text_generation_server/models/galactica.py   +4  -4
router/src/validation.rs

@@ -271,23 +271,23 @@ pub(crate) struct ValidGenerateRequest {
 #[derive(Error, Debug)]
 pub enum ValidationError {
-    #[error("temperature must be strictly positive")]
+    #[error("`temperature` must be strictly positive")]
     Temperature,
-    #[error("repetition_penalty must be strictly positive")]
+    #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("top_p must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and <= 1.0")]
     TopP,
-    #[error("top_k must be strictly positive")]
+    #[error("`top_k` must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive")]
+    #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
-    #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
+    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
-    #[error("inputs must have less than {0} tokens. Given: {1}")]
+    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
-    #[error("inputs cannot be empty")]
+    #[error("`inputs` cannot be empty")]
     EmptyInput,
-    #[error("stop supports up to {0} stop sequences. Given: {1}")]
+    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
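
The only substantive edit in this file wraps parameter names in backticks so API error messages render them as code. A client can mirror these constraints before a request ever reaches the router; a minimal Python sketch of such a pre-check (the function and its defaults are hypothetical, not part of this commit):

    # Hypothetical client-side pre-check mirroring the ValidationError rules
    # above; not part of this commit.
    def validate_parameters(
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 1,
        max_new_tokens: int = 20,
    ) -> None:
        if temperature <= 0:
            raise ValueError("`temperature` must be strictly positive")
        if repetition_penalty <= 0:
            raise ValueError("`repetition_penalty` must be strictly positive")
        if not 0.0 < top_p <= 1.0:
            raise ValueError("`top_p` must be > 0.0 and <= 1.0")
        if top_k <= 0:
            raise ValueError("`top_k` must be strictly positive")
        if max_new_tokens <= 0:
            raise ValueError("`max_new_tokens` must be strictly positive")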
server/.gitignore

 # Byte-compiled / optimized / DLL files
 __pycache__/
-text_generation/__pycache__/
+text_generation_server/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/pb/__pycache__/
 *.py[cod]
 *$py.class
server/Makefile

@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
+	mkdir text_generation_server/pb || true
-	python -m grpc_tools.protoc -I ../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
+	python -m grpc_tools.protoc -I ../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	touch text_generation_server/pb/__init__.py

 install-transformers:
 	# Install specific version of transformers with custom cuda kernels

@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
 	pip install -e . --no-cache-dir

 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file
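
Two details of gen-server are worth noting. First, protoc writes the generated stubs with absolute imports (import generate_pb2), which break once the files live inside a package; the sed pass rewrites them to relative imports (from . import generate_pb2). Second, every occurrence of the pb output directory must track the package rename, which is why the target touches four lines. After running it, the stubs import under the new name:

    # Sketch: the generated modules live in text_generation_server/pb after
    # `make gen-server`. Message definitions come from ../proto/generate.proto.
    from text_generation_server.pb import generate_pb2, generate_pb2_grpc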
server/pyproject.toml

 [tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
 version = "0.3.2"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

 [tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'

 [tool.poetry.dependencies]
 python = "^3.9"
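
The [tool.poetry.scripts] entry means installation creates a text-generation-server executable that resolves to the app object in text_generation_server/cli.py, a Typer application (see the cli.py diff below). A minimal sketch of that wiring, with a placeholder command:

    # Sketch: how the console script 'text_generation_server.cli:app' is wired.
    # The real cli.py defines serve and download-weights commands; this
    # placeholder only illustrates the Typer entry point.
    import typer

    app = typer.Typer()

    @app.command()
    def serve(model_id: str, sharded: bool = False):
        """Placeholder for the real serve command."""
        typer.echo(f"serving {model_id} (sharded={sharded})")

    if __name__ == "__main__":
        app()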
server/tests/conftest.py

 import pytest
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2

 @pytest.fixture
server/tests/models/test_bloom.py

@@ -4,9 +4,9 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
+from text_generation_server.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM

 @pytest.fixture(scope="session")
server/tests/models/test_causal_lm.py

@@ -4,8 +4,8 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch

 @pytest.fixture(scope="session")
server/tests/models/test_santacoder.py

 import pytest
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
+from text_generation_server.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.models.santacoder import SantaCoder

 @pytest.fixture(scope="session")
server/tests/models/test_seq2seq_lm.py

@@ -5,8 +5,8 @@ from copy import copy
 from transformers import AutoTokenizer
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch

 @pytest.fixture(scope="session")
server/tests/utils/test_convert.py

-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)
-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.convert import convert_files

 def test_convert_files():
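
Besides the rename, this test's hub import is reflowed into a parenthesized multi-line form (hence +6 -2). The utilities it exercises implement the weight-conversion flow: list a model's .bin files on the Hub, download them, convert to .safetensors. A hedged sketch of that flow (signatures are approximate; check utils/hub.py and utils/convert.py):

    # Hedged sketch of the conversion flow covered by test_convert_files;
    # exact helper signatures may differ from the real utilities.
    from text_generation_server.utils.hub import download_weights, weight_hub_files
    from text_generation_server.utils.convert import convert_files

    model_id = "bigscience/bloom-560m"
    # 1. List the PyTorch weight files hosted on the Hub for this model.
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    # 2. Download them to the local cache.
    local_pt_files = download_weights(pt_filenames, model_id)
    # 3. Convert each .bin file to a .safetensors sibling.
    local_st_files = [p.with_suffix(".safetensors") for p in local_pt_files]
    convert_files(local_pt_files, local_st_files)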
server/tests/utils/test_hub.py

 import pytest
-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
     weight_hub_files,
     download_weights,
     weight_files,
server/tests/utils/test_tokens.py

-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
server/text_generation/__init__.py → server/text_generation_server/__init__.py

File moved
server/text_generation/cache.py → server/text_generation_server/cache.py

 from typing import Dict, Optional, TypeVar
-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch

 B = TypeVar("B", bound=Batch)
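
The TypeVar bound to Batch lets the cache stay generic over each model's concrete batch type: whatever batch subclass goes in is what comes back out. A minimal sketch of that generic-cache pattern (class shape is illustrative; the real Cache lives in this file and may differ in detail):

    # Minimal sketch of a batch cache keyed by id, generic over the batch type.
    from typing import Dict, Generic, Optional, TypeVar

    class Batch:  # stand-in for text_generation_server.models.types.Batch
        batch_id: int

    B = TypeVar("B", bound=Batch)

    class Cache(Generic[B]):
        def __init__(self) -> None:
            self.cache: Dict[int, B] = {}

        def set(self, entry: B) -> None:
            self.cache[entry.batch_id] = entry

        def pop(self, batch_id: int) -> Optional[B]:
            # Remove and return the batch, or None if it was never cached.
            return self.cache.pop(batch_id, None)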
server/text_generation/cli.py → server/text_generation_server/cli.py

@@ -6,8 +6,8 @@ from pathlib import Path
 from loguru import logger
 from typing import Optional
-from text_generation import server, utils
+from text_generation_server import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server.tracing import setup_tracing

 app = typer.Typer()

@@ -42,7 +42,7 @@ def serve(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,

@@ -68,7 +68,7 @@ def download_weights(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,
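
The filter="text_generation" → filter="text_generation_server" edits matter: loguru's string filter matches records by the emitting module's __name__, so a sink filtered on the old package name would silently drop every log line from the renamed package. A small illustration of the semantics:

    # Illustration of loguru's string filter: only records whose module path
    # (record["name"]) starts with the given prefix reach the sink.
    import sys
    from loguru import logger

    logger.remove()
    logger.add(sys.stdout, format="{name}: {message}", filter="text_generation_server")

    # Emitted from __main__, so the sink above drops it; records emitted from
    # any module under text_generation_server.* would pass.
    logger.info("dropped")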
server/text_generation/interceptor.py → server/text_generation_server/interceptor.py

File moved
server/text_generation/models/__init__.py → server/text_generation_server/models/__init__.py

@@ -3,14 +3,14 @@ import torch
 from transformers import AutoConfig
 from typing import Optional
-from text_generation.models.model import Model
+from text_generation_server.models.model import Model
-from text_generation.models.causal_lm import CausalLM
+from text_generation_server.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.t5 import T5Sharded

 __all__ = [
     "Model",
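
models/__init__.py is the dispatch point: it re-exports every model class under the new package name and hosts the factory that picks an implementation from a checkpoint's config. A rough sketch of that dispatch shape (the function is illustrative; the real factory also handles revisions, sharding, quantization, and more architectures):

    # Rough sketch of config-based model dispatch; the imported classes match
    # this file's imports, the function itself is illustrative.
    from typing import Optional
    from transformers import AutoConfig
    from text_generation_server.models import BLOOM, CausalLM, Seq2SeqLM

    def get_model_sketch(model_id: str, revision: Optional[str] = None):
        config = AutoConfig.from_pretrained(model_id, revision=revision)
        if config.model_type == "bloom":
            return BLOOM(model_id, revision=revision)
        if config.is_encoder_decoder:
            return Seq2SeqLM(model_id, revision=revision)
        # Fall back to the generic decoder-only path.
        return CausalLM(model_id, revision=revision)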
server/text_generation/models/bloom.py → server/text_generation_server/models/bloom.py

@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )
-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
+from text_generation_server.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
server/text_generation/models/causal_lm.py → server/text_generation_server/models/causal_lm.py

@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
-from text_generation.models import Model
+from text_generation_server.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

 tracer = trace.get_tracer(__name__)
server/text_generation/models/galactica.py → server/text_generation_server/models/galactica.py

@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
     TensorParallelRowLinear,
 )
-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
+from text_generation_server.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
     initialize_torch_distributed,
Page 1 of 3; the remaining changed files of this commit are shown on the following pages.