Unverified commit 3fef90d5 authored by OlivierDehaene, committed by GitHub

feat(clients): Python client (#103)

parent 0e9ed1a8
@@ -271,23 +271,23 @@ pub(crate) struct ValidGenerateRequest {
 #[derive(Error, Debug)]
 pub enum ValidationError {
-    #[error("temperature must be strictly positive")]
+    #[error("`temperature` must be strictly positive")]
     Temperature,
-    #[error("repetition_penalty must be strictly positive")]
+    #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("top_p must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and <= 1.0")]
     TopP,
-    #[error("top_k must be strictly positive")]
+    #[error("`top_k` must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive")]
+    #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
-    #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
+    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
-    #[error("inputs must have less than {0} tokens. Given: {1}")]
+    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
-    #[error("inputs cannot be empty")]
+    #[error("`inputs` cannot be empty")]
     EmptyInput,
-    #[error("stop supports up to {0} stop sequences. Given: {1}")]
+    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
...
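Note: the error strings above double as the contract for the request parameters. For illustration only, a client could mirror the same checks before issuing a request; the validate_params helper below is a hypothetical sketch derived from the messages in this enum, not code from this commit.

# Hypothetical client-side mirror of the router's validation rules.
# Constraints are taken from the error messages above; the helper
# itself is illustrative and not part of this commit.
def validate_params(
    temperature: float = 1.0,
    repetition_penalty: float = 1.0,
    top_p: float = 1.0,
    top_k: int = 1,
    max_new_tokens: int = 20,
) -> None:
    if temperature <= 0:
        raise ValueError("`temperature` must be strictly positive")
    if repetition_penalty <= 0:
        raise ValueError("`repetition_penalty` must be strictly positive")
    if not 0.0 < top_p <= 1.0:
        raise ValueError("`top_p` must be > 0.0 and <= 1.0")
    if top_k <= 0:
        raise ValueError("`top_k` must be strictly positive")
    if max_new_tokens <= 0:
        raise ValueError("`max_new_tokens` must be strictly positive")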
 # Byte-compiled / optimized / DLL files
 __pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
 *.py[cod]
 *$py.class
...
@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py
 
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
 	pip install -e . --no-cache-dir
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file
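The find/sed step in gen-server exists because grpc_tools.protoc emits stubs that import each other by bare module name, which breaks once they live inside a package. A small Python reproduction of the sed expression's effect (the sample line mirrors what protoc typically generates):

import re

# The Makefile's sed expression, reproduced in Python: prefix any
# line starting with `import ...pb2` with `from . ` so the generated
# modules resolve as package-relative imports.
line = "import generate_pb2 as generate__pb2"
rewritten = re.sub(r"^(import.*pb2)", r"from . \1", line)
print(rewritten)  # from . import generate_pb2 as generate__pb2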
 [tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
 version = "0.3.2"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 
 [tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'
 
 [tool.poetry.dependencies]
 python = "^3.9"
...
 import pytest
 
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
 
 @pytest.fixture
...
@@ -4,9 +4,9 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
 
 @pytest.fixture(scope="session")
...
@@ -4,8 +4,8 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 
 @pytest.fixture(scope="session")
...
 import pytest
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.santacoder import SantaCoder
 
 @pytest.fixture(scope="session")
...
@@ -5,8 +5,8 @@ from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
 
 @pytest.fixture(scope="session")
...
-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)
+from text_generation_server.utils.convert import convert_files
 
 def test_convert_files():
...
 import pytest
 
-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
     weight_hub_files,
     download_weights,
     weight_files,
...
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
...
 from typing import Dict, Optional, TypeVar
 
-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch
 
 B = TypeVar("B", bound=Batch)
...
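The B = TypeVar("B", bound=Batch) declaration points at a container that is generic over the batch type. A rough sketch of such a structure follows; the batch_id field and the method names are assumptions for illustration, not taken from this diff.

from typing import Dict, Generic, Optional, TypeVar

class Batch:
    """Stand-in for text_generation_server.models.types.Batch."""
    batch_id: int  # assumed identifier field

B = TypeVar("B", bound=Batch)

class Cache(Generic[B]):
    # Keeps in-flight batches addressable by id between decode steps.
    def __init__(self) -> None:
        self.cache: Dict[int, B] = {}

    def set(self, entry: B) -> None:
        self.cache[entry.batch_id] = entry

    def pop(self, batch_id: int) -> Optional[B]:
        return self.cache.pop(batch_id, None)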
@@ -6,8 +6,8 @@ from pathlib import Path
 from loguru import logger
 from typing import Optional
 
-from text_generation import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server import server, utils
+from text_generation_server.tracing import setup_tracing
 
 app = typer.Typer()
@@ -42,7 +42,7 @@ def serve(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,
@@ -68,7 +68,7 @@ def download_weights(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,
...
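The filter="text_generation_server" change matters because loguru's string filter only passes records emitted from the named module and its children; keeping the old "text_generation" value after the rename would silently drop every server log line. A minimal sketch of the pattern, with options simplified relative to the file above:

import sys

import typer
from loguru import logger

app = typer.Typer()

@app.command()
def serve(json_output: bool = False) -> None:
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        # Must match the importable package name: a stale
        # "text_generation" filter would drop all records emitted
        # from the renamed text_generation_server package.
        filter="text_generation_server",
        level="INFO",
        serialize=json_output,
        backtrace=True,
    )

if __name__ == "__main__":
    app()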
@@ -3,14 +3,14 @@ import torch
 from transformers import AutoConfig
 from typing import Optional
 
-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
     "Model",
...
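The import list above, together with AutoConfig, feeds a model factory that picks an implementation from the Hugging Face config. The sketch below is illustrative only: the branching and the constructor signatures are simplified assumptions, not this file's actual logic.

# Illustrative sketch of config-based model dispatch; the real
# factory also handles revisions, quantization, and more architectures.
from transformers import AutoConfig

from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded

def get_model(model_id: str, sharded: bool = False):
    # Assumed dispatch key: the `model_type` field of the HF config.
    config = AutoConfig.from_pretrained(model_id)
    if config.model_type == "bloom":
        return BLOOMSharded(model_id) if sharded else BLOOM(model_id)
    if config.model_type == "gpt_neox":
        return GPTNeoxSharded(model_id) if sharded else GPTNeox(model_id)
    if config.model_type == "t5":
        return T5Sharded(model_id)
    if sharded:
        raise ValueError("sharded mode is not supported for this model type")
    return CausalLM(model_id)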
@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
...
@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 tracer = trace.get_tracer(__name__)
...
@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
     initialize_torch_distributed,
...