Unverified Commit 3fef90d5 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(clients): Python client (#103)

parent 0e9ed1a8
...@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import ( ...@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation.models import CausalLM from text_generation_server.models import CausalLM
from text_generation.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
) )
......
...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from text_generation.models.types import Batch, GeneratedText from text_generation_server.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
......
...@@ -4,7 +4,7 @@ import torch.distributed ...@@ -4,7 +4,7 @@ import torch.distributed
from typing import Optional, List from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>" FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>" FIM_MIDDLE = "<fim-middle>"
......
...@@ -5,10 +5,15 @@ from opentelemetry import trace ...@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type from typing import Optional, Tuple, List, Type
from text_generation.models import Model from text_generation_server.models import Model
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens from text_generation_server.models.types import (
from text_generation.pb import generate_pb2 GeneratedText,
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling Batch,
Generation,
PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
...@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch): ...@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
padding_right_offset: int padding_right_offset: int
def to_pb(self) -> generate_pb2.Batch: def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf""" """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch( return generate_pb2.Batch(
id=self.batch_id, id=self.batch_id,
requests=self.requests, requests=self.requests,
...@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch): ...@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
device: torch.device, device: torch.device,
) -> "Seq2SeqLMBatch": ) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch""" """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
......
...@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import ( ...@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation.models import Seq2SeqLM from text_generation_server.models import Seq2SeqLM
from text_generation.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
) )
......
...@@ -6,8 +6,8 @@ from typing import List, Optional ...@@ -6,8 +6,8 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason
class Batch(ABC): class Batch(ABC):
......
...@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection ...@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
from text_generation.cache import Cache from text_generation_server.cache import Cache
from text_generation.interceptor import ExceptionInterceptor from text_generation_server.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model from text_generation_server.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
......
from text_generation.utils.convert import convert_file, convert_files from text_generation_server.utils.convert import convert_file, convert_files
from text_generation.utils.dist import initialize_torch_distributed from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation.utils.hub import ( from text_generation_server.utils.hub import (
weight_files, weight_files,
weight_hub_files, weight_hub_files,
download_weights, download_weights,
...@@ -8,7 +8,7 @@ from text_generation.utils.hub import ( ...@@ -8,7 +8,7 @@ from text_generation.utils.hub import (
LocalEntryNotFoundError, LocalEntryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
) )
from text_generation.utils.tokens import ( from text_generation_server.utils.tokens import (
Greedy, Greedy,
NextTokenChooser, NextTokenChooser,
Sampling, Sampling,
......
...@@ -11,9 +11,9 @@ from transformers import ( ...@@ -11,9 +11,9 @@ from transformers import (
) )
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.watermark import WatermarkLogitsProcessor
class Sampling: class Sampling:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment