Unverified commit 3fef90d5 authored by OlivierDehaene, committed by GitHub

feat(clients): Python client (#103)

parent 0e9ed1a8
@@ -271,23 +271,23 @@ pub(crate) struct ValidGenerateRequest {
 #[derive(Error, Debug)]
 pub enum ValidationError {
-    #[error("temperature must be strictly positive")]
+    #[error("`temperature` must be strictly positive")]
     Temperature,
-    #[error("repetition_penalty must be strictly positive")]
+    #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("top_p must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and <= 1.0")]
     TopP,
-    #[error("top_k must be strictly positive")]
+    #[error("`top_k` must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive")]
+    #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
-    #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
+    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
-    #[error("inputs must have less than {0} tokens. Given: {1}")]
+    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
-    #[error("inputs cannot be empty")]
+    #[error("`inputs` cannot be empty")]
     EmptyInput,
-    #[error("stop supports up to {0} stop sequences. Given: {1}")]
+    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),

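These router-side messages are what the new Python client ultimately relays when a request is rejected, which is why the parameter names gain backticks: they now read as code identifiers in client-facing errors. A hedged sketch of triggering one of them through the client this PR adds, assuming the client lives in the `text_generation` package (freed up by the server rename below) and exposes `Client` plus a `ValidationError`:

```python
from text_generation import Client
from text_generation.errors import ValidationError  # assumed module layout

client = Client("http://127.0.0.1:8080")

try:
    # Violates "`temperature` must be strictly positive".
    client.generate("What is Deep Learning?", temperature=0)
except ValidationError as err:
    print(err)
```
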
 # Byte-compiled / optimized / DLL files
 __pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
 *.py[cod]
 *$py.class

@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py
 
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels

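The `find`/`sed` step in `gen-server` patches the code that `grpc_tools.protoc` emits: the generated gRPC stub imports its sibling `*_pb2` module with a top-level import, which breaks once the files live inside the `pb` package. A runnable sketch mirroring the sed expression in Python, applied to the kind of line protoc generates for `generate.proto`:

```python
import re

# sed 's/^\(import.*pb2\)/from . \1/g' rewrites top-level pb2 imports in the
# generated stubs into relative imports that resolve within the pb package.
line = "import generate_pb2 as generate__pb2"
patched = re.sub(r"^(import.*pb2)", r"from . \1", line)
print(patched)  # -> "from . import generate_pb2 as generate__pb2"
```
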
@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
 	pip install -e . --no-cache-dir
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file

 [tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
 version = "0.3.2"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 
 [tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'
 
 [tool.poetry.dependencies]
 python = "^3.9"

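The `[tool.poetry.scripts]` entry keeps the installed command named `text-generation-server` while retargeting it at the renamed package; the path after the colon must resolve to the Typer application object. A minimal sketch of the shape of `text_generation_server/cli.py`, using the two commands visible later in this diff; the parameters here are assumptions, not the file's real signatures:

```python
# Hypothetical reduction of text_generation_server/cli.py.
import typer

app = typer.Typer()  # the object the console script points at

@app.command()
def serve(model_id: str, sharded: bool = False):
    """Start the gRPC model server (parameters assumed)."""
    ...

@app.command()
def download_weights(model_id: str):
    """Fetch model weights ahead of serving (assumed)."""
    ...

if __name__ == "__main__":
    app()
```
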
 import pytest
 
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
 
 
 @pytest.fixture

@@ -4,9 +4,9 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
 
 
 @pytest.fixture(scope="session")

@@ -4,8 +4,8 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 
 
 @pytest.fixture(scope="session")

 import pytest
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.santacoder import SantaCoder
 
 
 @pytest.fixture(scope="session")

@@ -5,8 +5,8 @@ from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
 
 
 @pytest.fixture(scope="session")

-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)
 
-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.convert import convert_files
 
 
 def test_convert_files():

 import pytest
 
-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
     weight_hub_files,
     download_weights,
     weight_files,

-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,

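These utilities implement the `stop` parameter validated in the Rust hunk above. A hedged sketch of how a stop-sequence criteria of this shape is typically consulted during decoding; the constructor and call signature are assumptions read off the names, not taken from the file:

```python
from text_generation_server.utils.tokens import StopSequenceCriteria

# Assumption: constructed from one stop string and called on the text decoded
# so far, returning truthy once that text ends with the stop sequence.
criteria = StopSequenceCriteria("\n\n")

decoded = "The answer is 42.\n\n"
if criteria(decoded):
    print("stop sequence reached; finish this request")
```
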
 from typing import Dict, Optional, TypeVar
 
-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch
 
 B = TypeVar("B", bound=Batch)

@@ -6,8 +6,8 @@ from pathlib import Path
 from loguru import logger
 from typing import Optional
 
-from text_generation import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server import server, utils
+from text_generation_server.tracing import setup_tracing
 
 app = typer.Typer()

@@ -42,7 +42,7 @@ def serve(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,

@@ -68,7 +68,7 @@ def download_weights(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,

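Both sinks in `cli.py` filter log records by module namespace, so the filter string has to follow the package rename; left at `text_generation`, every record emitted from the renamed server modules would be silently dropped. A standalone sketch of that loguru behavior (not the file's full configuration):

```python
import sys
from loguru import logger

logger.remove()  # drop the default sink
# A string filter keeps only records whose module path is the named package
# or one of its children, e.g. text_generation_server.models.bloom.
logger.add(sys.stdout, format="{message}", filter="text_generation_server", level="INFO")
```
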
@@ -3,14 +3,14 @@ import torch
 from transformers import AutoConfig
 from typing import Optional
 
-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
     "Model",

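This `models/__init__.py` is where an incoming model id gets mapped to one of the imported classes. The dispatch itself sits outside this hunk; a plausible sketch of its shape, using only names visible above — the `get_model` name and its signature are assumptions here:

```python
from transformers import AutoConfig

from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.causal_lm import CausalLM

def get_model(model_id: str, sharded: bool = False):
    # Hypothetical dispatch: read the architecture from the config and pick a
    # specialized implementation, falling back to the generic CausalLM.
    config = AutoConfig.from_pretrained(model_id)
    if config.model_type == "bloom":
        return BLOOMSharded(model_id) if sharded else BLOOM(model_id)
    return CausalLM(model_id)
```
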
@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )

@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 tracer = trace.get_tracer(__name__)

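The module-level `tracer` pairs with the `opentelemetry` import shown in the hunk header, so methods in this file can be traced per decoding step. A minimal sketch of the pattern, with an illustrative span name rather than one taken from the file:

```python
from opentelemetry import trace

tracer = trace.get_tracer(__name__)

# start_as_current_span works both as a decorator and a context manager;
# the span name "generate_token" is illustrative only.
@tracer.start_as_current_span("generate_token")
def generate_token(batch):
    ...
```
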
@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
     initialize_torch_distributed,