Unverified Commit 3fef90d5 authored by OlivierDehaene, committed by GitHub

feat(clients): Python client (#103)

parent 0e9ed1a8
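
This commit adds the Python client and, as the diff below shows, renames the server-side package from `text_generation` to `text_generation_server`, freeing the `text_generation` name for the new client. As a minimal sketch of what client usage might look like (the `Client` class, base URL, and `generate()` parameters are assumptions based on the commit title, not part of this diff):

```python
# Hypothetical usage of the new `text_generation` Python client from #103.
# The Client class, base URL, and generate() signature are assumptions;
# only the server-side package rename is shown in the diff below.
from text_generation import Client

client = Client("http://127.0.0.1:8080")
response = client.generate("Why is the sky blue?", max_new_tokens=20)
print(response.generated_text)
```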
@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
     TensorParallelRowLinear,
 )
-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
...
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase
-from text_generation.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, GeneratedText
 B = TypeVar("B", bound=Batch)
...
@@ -4,7 +4,7 @@ import torch.distributed
 from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
...
@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 tracer = trace.get_tracer(__name__)
@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
     padding_right_offset: int
     def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
...
@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
 )
-from text_generation.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.models import Seq2SeqLM
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
...
@@ -6,8 +6,8 @@ from typing import List, Optional
 from transformers import PreTrainedTokenizerBase
-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
 class Batch(ABC):
...
@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
 from pathlib import Path
 from typing import List, Optional
-from text_generation.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
...
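
The server hunk above moves every import of the gRPC service onto the renamed `text_generation_server` package. A minimal sketch of wiring these modules into a `grpc.aio` server over a Unix domain socket, assuming a no-argument `ExceptionInterceptor` and a hypothetical `serve()` entry point (only the module paths come from the diff):

```python
# Sketch: serving the renamed modules with grpc.aio over a Unix domain socket.
# serve()'s signature and the socket path are assumptions; the import paths
# match the renamed modules in the hunk above.
import asyncio

from grpc import aio

from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.pb import generate_pb2_grpc


async def serve(uds_path: str = "/tmp/text-generation-server-0") -> None:
    server = aio.server(interceptors=[ExceptionInterceptor()])
    # A concrete TextGenerationService servicer would be registered here:
    # generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(service, server)
    server.add_insecure_port(f"unix://{uds_path}")
    await server.start()
    await server.wait_for_termination()


if __name__ == "__main__":
    asyncio.run(serve())
```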
-from text_generation.utils.convert import convert_file, convert_files
-from text_generation.utils.dist import initialize_torch_distributed
-from text_generation.utils.hub import (
+from text_generation_server.utils.convert import convert_file, convert_files
+from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
     download_weights,
@@ -8,7 +8,7 @@ from text_generation.utils.hub import (
     LocalEntryNotFoundError,
     RevisionNotFoundError,
 )
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     Greedy,
     NextTokenChooser,
     Sampling,
...
@@ -11,9 +11,9 @@ from transformers import (
 )
 from typing import List, Tuple, Optional
-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
-from text_generation.utils.watermark import WatermarkLogitsProcessor
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 class Sampling:
...