Unverified commit 9af45414 authored by OlivierDehaene, committed by GitHub

feat: add distributed tracing (#62)

parent e520d5b3
@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
+use tracing::{instrument, Span};
const MAX_MAX_NEW_TOKENS: u32 = 512;
const MAX_STOP_SEQUENCES: usize = 4;
@@ -36,6 +37,7 @@ impl Validation {
}
/// Validate a payload and get the number of tokens in the input
+    #[instrument(skip_all)]
pub(crate) async fn validate(
&self,
request: GenerateRequest,
@@ -44,7 +46,10 @@ impl Validation {
let (sender, receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
-        self.sender.send((request, sender)).await.unwrap();
+        self.sender
+            .send((request, sender, Span::current()))
+            .await
+            .unwrap();
// Await on response channel
// Unwrap is safe here
receiver.await.unwrap()
@@ -97,10 +102,17 @@ fn validation_worker(
let mut rng = rand::thread_rng();
// Loop over requests
-    while let Some((request, response_tx)) = receiver.blocking_recv() {
-        response_tx
-            .send(validate(request, &tokenizer, max_input_length, &mut rng))
-            .unwrap_or(())
+    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
+        parent_span.in_scope(|| {
+            response_tx
+                .send(
+                    validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
+                        tracing::error!("{err}");
+                        err
+                    }),
+                )
+                .unwrap_or(())
+        })
}
}
@@ -203,6 +215,7 @@ fn validate(
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
+    Span,
);
#[derive(Debug)]
......
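Note on the validation change above: the router now captures `Span::current()` when a request is queued on the mpsc channel and re-enters that span (`parent_span.in_scope`) inside the blocking validation worker, so tokenization work and validation errors are recorded under the originating request's trace. Below is a minimal Python sketch of the same propagation pattern using OpenTelemetry's context API — illustrative only; `submit` and `worker` are hypothetical names, not part of this commit:

```python
import queue
import threading

from opentelemetry import context, trace

tracer = trace.get_tracer(__name__)
jobs: queue.Queue = queue.Queue()

def submit(request):
    # Snapshot the caller's context (it carries the active span), like Span::current().
    jobs.put((request, context.get_current()))

def worker():
    while True:
        request, parent_ctx = jobs.get()
        # Re-attach the caller's context in this thread, like parent_span.in_scope(...).
        token = context.attach(parent_ctx)
        try:
            with tracer.start_as_current_span("validate"):
                pass  # tokenize and validate the request here
        finally:
            context.detach(token)

threading.Thread(target=worker, daemon=True).start()
```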
[toolchain]
channel = "1.65.0"
channel = "1.67.0"
components = ["rustfmt", "clippy"]
\ No newline at end of file
gen-server:
# Compile protos
-	pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.51.1 --no-cache-dir
mkdir text_generation/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
......
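The `sed` step in the `gen-server` target fixes the absolute imports emitted by `grpc_tools.protoc` so the generated stubs work inside a package. Roughly, in the generated `text_generation/pb/*_pb2_grpc.py` files it rewrites the top-level import into a relative one:

```python
# Before (emitted by grpc_tools.protoc; breaks once the stubs live in a package):
#   import generate_pb2 as generate__pb2
#
# After the sed rewrite (valid inside the text_generation.pb package):
#   from . import generate_pb2 as generate__pb2
```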
(A large diff is collapsed by the viewer and omitted here.)
@@ -19,12 +19,15 @@ accelerate = "^0.15.0"
bitsandbytes = "^0.35.1"
safetensors = "^0.2.4"
loguru = "^0.6.0"
+opentelemetry-api = "^1.15.0"
+opentelemetry-exporter-otlp = "^1.15.0"
+opentelemetry-instrumentation-grpc = "^0.36b0"
[tool.poetry.extras]
bnb = ["bitsandbytes"]
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.49.1"
grpcio-tools = "^1.51.1"
pytest = "^7.2.0"
[build-system]
......
@@ -7,6 +7,7 @@ from loguru import logger
from typing import Optional
from text_generation import server, utils
+from text_generation.tracing import setup_tracing
app = typer.Typer()
@@ -20,18 +21,8 @@ def serve(
uds_path: Path = "/tmp/text-generation",
logger_level: str = "INFO",
json_output: bool = False,
+    otlp_endpoint: Optional[str] = None,
):
-    # Remove default handler
-    logger.remove()
-    logger.add(
-        sys.stdout,
-        format="{message}",
-        filter="text_generation",
-        level=logger_level,
-        serialize=json_output,
-        backtrace=True,
-        diagnose=False,
-    )
if sharded:
assert (
os.getenv("RANK", None) is not None
@@ -46,6 +37,21 @@ def serve(
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
+    # Setup OpenTelemetry distributed tracing
+    if otlp_endpoint is not None:
+        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
server.serve(model_id, revision, sharded, quantize, uds_path)
......
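Two things change in `cli.py`: the logger setup moves below the sharded-environment asserts, and a new `otlp_endpoint` parameter (surfaced by typer as an `--otlp-endpoint` flag) makes tracing strictly opt-in. A standalone sketch of that opt-in pattern — all names are hypothetical, with a stub standing in for `text_generation.tracing.setup_tracing`:

```python
import os
from typing import Optional

import typer

app = typer.Typer()

def setup_tracing(shard: int, otlp_endpoint: str) -> None:
    # Stand-in for text_generation.tracing.setup_tracing (see the new module below).
    print(f"shard {shard} exporting spans to {otlp_endpoint}")

@app.command()
def serve(model_id: str, otlp_endpoint: Optional[str] = None):
    # Tracing is configured only when the operator passes --otlp-endpoint.
    if otlp_endpoint is not None:
        setup_tracing(shard=int(os.getenv("RANK", 0)), otlp_endpoint=otlp_endpoint)

if __name__ == "__main__":
    app()
```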
import torch
from dataclasses import dataclass
+from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
@@ -9,6 +10,8 @@ from text_generation.models.types import Batch, PrefillTokens, Generation, Gener
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+tracer = trace.get_tracer(__name__)
@dataclass
class CausalLMBatch(Batch):
@@ -94,6 +97,7 @@
)
@classmethod
+    @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
@@ -286,6 +290,7 @@
)
return outputs.logits, outputs.past_key_values
+    @tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
@@ -331,8 +336,9 @@
all_input_ids,
) in enumerate(iterator):
# Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
-            next_token_id = tokens[-1].view(1, 1)
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits
+            )
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
......
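`causal_lm.py` gains a module-level tracer and wraps the two hottest batch operations, `concatenate` and `generate_token`, in spans. `start_as_current_span` works both as a decorator (as in the diff) and as a context manager. A minimal sketch of the pattern — the class and data are toy stand-ins, and without a configured provider the tracer is a no-op:

```python
from opentelemetry import trace

tracer = trace.get_tracer(__name__)

class Batch:
    @classmethod
    @tracer.start_as_current_span("concatenate")  # decorator form, as in the diff
    def concatenate(cls, batches):
        return [x for b in batches for x in b]

def generate_token(batch):
    with tracer.start_as_current_span("generate_token"):  # context-manager form
        return [x + 1 for x in batch]

print(generate_token(Batch.concatenate([[1, 2], [3]])))  # [2, 3, 4]
```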
import torch
from dataclasses import dataclass
+from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
@@ -9,6 +10,8 @@ from text_generation.models.types import GeneratedText, Batch, Generation, Prefi
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+tracer = trace.get_tracer(__name__)
@dataclass
class Seq2SeqLMBatch(Batch):
@@ -107,6 +110,7 @@
)
@classmethod
+    @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
"""Concatenate multiple batches together by padding internal torch tensors"""
@@ -361,6 +365,7 @@
outputs.past_key_values,
)
+    @tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
@@ -418,7 +423,7 @@
)
# Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
new_decoder_input_length = decoder_input_length + 1
# Generated token
......
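The `squeeze(1)` above is a shape fix that follows from the `NextTokenChooser` change further down: the chooser now returns the chosen id with shape `(1, 1)`, while `decoder_input_ids` is a 1-D tensor. A quick torch check of the bookkeeping (values are made up):

```python
import torch

decoder_input_ids = torch.tensor([0, 42, 7])  # shape (3,)
next_token_id = torch.tensor([[13]])          # shape (1, 1), as the chooser now returns
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
print(decoder_input_ids)  # tensor([ 0, 42,  7, 13])
```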
@@ -13,6 +13,7 @@ from text_generation.cache import Cache
from text_generation.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
+from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -100,7 +101,12 @@ def serve(
logger.exception("Error when initializing model")
raise
-    server = aio.server(interceptors=[ExceptionInterceptor()])
+    server = aio.server(
+        interceptors=[
+            ExceptionInterceptor(),
+            UDSOpenTelemetryAioServerInterceptor(),
+        ]
+    )
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls), server
)
......
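The gRPC server now runs every RPC through an interceptor chain, in list order: exception logging first, then the tracing interceptor defined in the new `tracing.py` module below. Because the shards talk over Unix domain sockets (`uds_path`), the tracing interceptor needs the UDS-aware span logic shown next. A bare-bones sketch of a `grpc.aio` server with an interceptor, bound to a Unix socket — the socket path and the toy interceptor are illustrative only:

```python
import asyncio
from grpc import aio

class LoggingInterceptor(aio.ServerInterceptor):
    # Toy stand-in for ExceptionInterceptor / the tracing interceptor below.
    async def intercept_service(self, continuation, handler_call_details):
        print(f"rpc: {handler_call_details.method}")
        return await continuation(handler_call_details)

async def main() -> None:
    # Interceptors wrap every handler in the order given.
    server = aio.server(interceptors=[LoggingInterceptor()])
    server.add_insecure_port("unix:///tmp/text-generation-demo-0")
    await server.start()
    await server.wait_for_termination()

if __name__ == "__main__":
    asyncio.run(main())
```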
import grpc
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.grpc._aio_server import (
OpenTelemetryAioServerInterceptor,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
)
class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
def __init__(self):
super().__init__(trace.get_tracer(__name__))
def _start_span(self, handler_call_details, context, set_status_on_exception=False):
"""
Rewrite _start_span method to support Unix Domain Socket gRPC contexts
"""
# standard attributes
attributes = {
SpanAttributes.RPC_SYSTEM: "grpc",
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],
}
# if we have details about the call, split into service and method
if handler_call_details.method:
service, method = handler_call_details.method.lstrip("/").split("/", 1)
attributes.update(
{
SpanAttributes.RPC_METHOD: method,
SpanAttributes.RPC_SERVICE: service,
}
)
# add some attributes from the metadata
metadata = dict(context.invocation_metadata())
if "user-agent" in metadata:
attributes["rpc.user_agent"] = metadata["user-agent"]
# We use gRPC over a UNIX socket
attributes.update({SpanAttributes.NET_TRANSPORT: "unix"})
return self._tracer.start_as_current_span(
name=handler_call_details.method,
kind=trace.SpanKind.SERVER,
attributes=attributes,
set_status_on_exception=set_status_on_exception,
)
def setup_tracing(shard: int, otlp_endpoint: str):
resource = Resource.create(
attributes={"service.name": f"text-generation-inference.server-{shard}"}
)
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)
trace.set_tracer_provider(TracerProvider(resource=resource))
trace.get_tracer_provider().add_span_processor(span_processor)
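`tracing.py` is a new module in this commit. The subclass exists because the server listens on a Unix domain socket: the stock `OpenTelemetryAioServerInterceptor._start_span` derives `net.peer.*` span attributes from the client's IP address, which a Unix-socket connection does not have, so the override tags spans with `net.transport = "unix"` instead. `setup_tracing` also names each shard's service individually (`...server-0`, `...server-1`, and so on) so traces from different shards stay distinguishable. A minimal usage sketch — it assumes the package is installed and an OTLP collector is reachable at `localhost:4317`; the endpoint and span name are examples, not part of the commit:

```python
from opentelemetry import trace
from text_generation.tracing import setup_tracing

setup_tracing(shard=0, otlp_endpoint="localhost:4317")

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("smoke-test"):
    pass  # exported asynchronously by the BatchSpanProcessor
```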
@@ -36,16 +36,14 @@ class Sampling:
self.seed = seed
def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(
-            probs, num_samples=1, generator=self.generator
-        ).squeeze(1)
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
-        return logits.argmax(dim=-1)
+        return logits.argmax()
class NextTokenChooser:
@@ -87,8 +85,9 @@
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
-        next_ids = self.choice(scores)
-        return next_ids, logprobs
+        next_id = self.choice(scores[-1])
+        return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
@@ -163,6 +162,7 @@ def initialize_torch_distributed():
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
@@ -181,7 +181,7 @@
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
-            pg_options=options
+            pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
......
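The `utils.py` changes narrow token selection to the final position: `Greedy` and `Sampling` now receive the 1-D `scores[-1]` row (hence the dropped `dim=-1` and `squeeze`), and `NextTokenChooser` returns the chosen id reshaped to `(1, 1)` together with logprobs over every position (used for prefill token logprobs). A small torch sketch of the shapes on the greedy path, with made-up sizes:

```python
import torch

scores = torch.randn(5, 100)              # (sequence_length, vocab_size)
logprobs = torch.log_softmax(scores, -1)  # kept for all positions
next_id = scores[-1].argmax()             # greedy pick on the last position only
print(next_id.view(1, 1).shape, logprobs.shape)
# torch.Size([1, 1]) torch.Size([5, 100])
```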