Unverified Commit 9af45414 authored by OlivierDehaene, committed by GitHub

feat: add distributed tracing (#62)

parent e520d5b3
@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};
+use tracing::{instrument, Span};

 const MAX_MAX_NEW_TOKENS: u32 = 512;
 const MAX_STOP_SEQUENCES: usize = 4;
@@ -36,6 +37,7 @@ impl Validation {
     }

     /// Validate a payload and get the number of tokens in the input
+    #[instrument(skip_all)]
     pub(crate) async fn validate(
         &self,
         request: GenerateRequest,
@@ -44,7 +46,10 @@ impl Validation {
         let (sender, receiver) = oneshot::channel();
         // Send request to the background validation task
         // Unwrap is safe here
-        self.sender.send((request, sender)).await.unwrap();
+        self.sender
+            .send((request, sender, Span::current()))
+            .await
+            .unwrap();
         // Await on response channel
         // Unwrap is safe here
         receiver.await.unwrap()
@@ -97,10 +102,17 @@ fn validation_worker(
     let mut rng = rand::thread_rng();

     // Loop over requests
-    while let Some((request, response_tx)) = receiver.blocking_recv() {
-        response_tx
-            .send(validate(request, &tokenizer, max_input_length, &mut rng))
-            .unwrap_or(())
+    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
+        parent_span.in_scope(|| {
+            response_tx
+                .send(
+                    validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
+                        tracing::error!("{err}");
+                        err
+                    }),
+                )
+                .unwrap_or(())
+        })
     }
 }
@@ -203,6 +215,7 @@ fn validate(
 type ValidationRequest = (
     GenerateRequest,
     oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
+    Span,
 );

 #[derive(Debug)]
......
 [toolchain]
-channel = "1.65.0"
+channel = "1.67.0"
 components = ["rustfmt", "clippy"]
\ No newline at end of file
 gen-server:
 	# Compile protos
-	pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.51.1 --no-cache-dir
 	mkdir text_generation/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
 	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
......
(The diff of one additional file is collapsed and not shown here.)
@@ -19,12 +19,15 @@ accelerate = "^0.15.0"
 bitsandbytes = "^0.35.1"
 safetensors = "^0.2.4"
 loguru = "^0.6.0"
+opentelemetry-api = "^1.15.0"
+opentelemetry-exporter-otlp = "^1.15.0"
+opentelemetry-instrumentation-grpc = "^0.36b0"

 [tool.poetry.extras]
 bnb = ["bitsandbytes"]

 [tool.poetry.group.dev.dependencies]
-grpcio-tools = "^1.49.1"
+grpcio-tools = "^1.51.1"
 pytest = "^7.2.0"

 [build-system]
......
@@ -7,6 +7,7 @@ from loguru import logger
 from typing import Optional

 from text_generation import server, utils
+from text_generation.tracing import setup_tracing

 app = typer.Typer()

@@ -20,18 +21,8 @@ def serve(
     uds_path: Path = "/tmp/text-generation",
     logger_level: str = "INFO",
     json_output: bool = False,
+    otlp_endpoint: Optional[str] = None,
 ):
-    # Remove default handler
-    logger.remove()
-    logger.add(
-        sys.stdout,
-        format="{message}",
-        filter="text_generation",
-        level=logger_level,
-        serialize=json_output,
-        backtrace=True,
-        diagnose=False,
-    )
     if sharded:
         assert (
             os.getenv("RANK", None) is not None
@@ -46,6 +37,21 @@ def serve(
             os.getenv("MASTER_PORT", None) is not None
         ), "MASTER_PORT must be set when sharded is True"

+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
+
+    # Setup OpenTelemetry distributed tracing
+    if otlp_endpoint is not None:
+        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

     server.serve(model_id, revision, sharded, quantize, uds_path)
......
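Aside: a minimal, hypothetical sketch of how the new parameter surfaces on the command line. Typer derives an `--otlp-endpoint` option from the `otlp_endpoint: Optional[str] = None` parameter, so tracing stays off unless the flag is passed; the model id and endpoint values below are illustrative, not taken from the diff.

    # sketch.py -- hypothetical, mirrors the CLI wiring in this commit
    from typing import Optional
    import typer

    app = typer.Typer()

    @app.command()
    def serve(model_id: str, otlp_endpoint: Optional[str] = None):
        if otlp_endpoint is not None:
            # In the real CLI this is where setup_tracing(...) is called.
            print(f"tracing enabled for {model_id}, exporting to {otlp_endpoint}")

    if __name__ == "__main__":
        # e.g. python sketch.py bigscience/bloom-560m --otlp-endpoint http://localhost:4317
        app()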
 import torch

 from dataclasses import dataclass
+from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

@@ -9,6 +10,8 @@ from text_generation.models.types import Batch, PrefillTokens, Generation, Gener
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

+tracer = trace.get_tracer(__name__)
+

 @dataclass
 class CausalLMBatch(Batch):
@@ -94,6 +97,7 @@ class CausalLMBatch(Batch):
         )

     @classmethod
+    @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = sum(batch.size for batch in batches)
@@ -286,6 +290,7 @@ class CausalLM(Model):
         )
         return outputs.logits, outputs.past_key_values

+    @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
@@ -331,8 +336,9 @@ class CausalLM(Model):
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
-            next_token_id = tokens[-1].view(1, 1)
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits
+            )

             # Append next token to all tokens
             all_input_ids = torch.cat([all_input_ids, next_token_id])
......
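For context on the decorator pattern above: `tracer.start_as_current_span("...")` returns a context manager that can also be applied as a decorator (as the diff does), so every call to the wrapped method runs inside a span of that name. A minimal sketch, with an illustrative function body; without a configured TracerProvider the spans are no-ops, which is why the server only calls setup_tracing when an OTLP endpoint is given.

    # Minimal sketch of the span-decorator pattern (illustrative names and body).
    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    @tracer.start_as_current_span("generate_token")
    def generate_token(batch):
        # Nested spans show up as children of "generate_token" once a provider is installed.
        with tracer.start_as_current_span("decode"):
            return [token for token in batch]

    print(generate_token(["a", "b"]))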
 import torch

 from dataclasses import dataclass
+from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

@@ -9,6 +10,8 @@ from text_generation.models.types import GeneratedText, Batch, Generation, Prefi
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

+tracer = trace.get_tracer(__name__)
+

 @dataclass
 class Seq2SeqLMBatch(Batch):
@@ -107,6 +110,7 @@ class Seq2SeqLMBatch(Batch):
         )

     @classmethod
+    @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
         """Concatenate multiple batches together by padding internal torch tensors"""

@@ -361,6 +365,7 @@ class Seq2SeqLM(Model):
             outputs.past_key_values,
         )

+    @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
@@ -418,7 +423,7 @@ class Seq2SeqLM(Model):
             )

             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
             new_decoder_input_length = decoder_input_length + 1

             # Generated token
......
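For intuition on the `.squeeze(1)` above: the updated NextTokenChooser now returns a `(1, 1)` tensor, so the extra dimension has to be dropped before concatenating onto the per-request decoder ids, which this sketch assumes are 1-D. The values are illustrative.

    import torch

    decoder_input_ids = torch.tensor([0, 5, 7])   # per-request decoder ids, assumed 1-D here
    next_token_id = torch.tensor([[42]])          # shape (1, 1), as returned by NextTokenChooser

    # squeeze(1) drops the trailing dimension so the shapes line up for torch.cat
    decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
    print(decoder_input_ids.shape)  # torch.Size([4])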
@@ -13,6 +13,7 @@ from text_generation.cache import Cache
 from text_generation.interceptor import ExceptionInterceptor
 from text_generation.models import Model, get_model
 from text_generation.pb import generate_pb2_grpc, generate_pb2
+from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -100,7 +101,12 @@ def serve(
             logger.exception("Error when initializing model")
             raise

-        server = aio.server(interceptors=[ExceptionInterceptor()])
+        server = aio.server(
+            interceptors=[
+                ExceptionInterceptor(),
+                UDSOpenTelemetryAioServerInterceptor(),
+            ]
+        )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
             TextGenerationService(model, Cache(), server_urls), server
         )
......
import grpc

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.grpc._aio_server import (
    OpenTelemetryAioServerInterceptor,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
    BatchSpanProcessor,
)


class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
    def __init__(self):
        super().__init__(trace.get_tracer(__name__))

    def _start_span(self, handler_call_details, context, set_status_on_exception=False):
        """
        Rewrite _start_span method to support Unix Domain Socket gRPC contexts
        """

        # standard attributes
        attributes = {
            SpanAttributes.RPC_SYSTEM: "grpc",
            SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],
        }

        # if we have details about the call, split into service and method
        if handler_call_details.method:
            service, method = handler_call_details.method.lstrip("/").split("/", 1)
            attributes.update(
                {
                    SpanAttributes.RPC_METHOD: method,
                    SpanAttributes.RPC_SERVICE: service,
                }
            )

        # add some attributes from the metadata
        metadata = dict(context.invocation_metadata())
        if "user-agent" in metadata:
            attributes["rpc.user_agent"] = metadata["user-agent"]

        # We use gRPC over a UNIX socket
        attributes.update({SpanAttributes.NET_TRANSPORT: "unix"})

        return self._tracer.start_as_current_span(
            name=handler_call_details.method,
            kind=trace.SpanKind.SERVER,
            attributes=attributes,
            set_status_on_exception=set_status_on_exception,
        )


def setup_tracing(shard: int, otlp_endpoint: str):
    resource = Resource.create(
        attributes={"service.name": f"text-generation-inference.server-{shard}"}
    )
    span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
    span_processor = BatchSpanProcessor(span_exporter)

    trace.set_tracer_provider(TracerProvider(resource=resource))
    trace.get_tracer_provider().add_span_processor(span_processor)
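To make the wiring concrete, a minimal sketch of how this new module is used elsewhere in the commit: `setup_tracing` installs a TracerProvider that batches spans to the OTLP endpoint, and the interceptor is handed to the gRPC aio server next to the existing exception interceptor. The shard number and endpoint below are illustrative values, not from the diff.

    # Illustrative wiring, mirroring cli.py and server.py in this commit.
    from grpc import aio

    from text_generation.interceptor import ExceptionInterceptor
    from text_generation.tracing import setup_tracing, UDSOpenTelemetryAioServerInterceptor

    # One tracer provider per shard; the resource name becomes
    # "text-generation-inference.server-0" for shard 0.
    setup_tracing(shard=0, otlp_endpoint="http://localhost:4317")

    server = aio.server(
        interceptors=[
            ExceptionInterceptor(),
            UDSOpenTelemetryAioServerInterceptor(),
        ]
    )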
@@ -36,16 +36,14 @@ class Sampling:
         self.seed = seed

     def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(
-            probs, num_samples=1, generator=self.generator
-        ).squeeze(1)
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
         return next_tokens


 class Greedy:
     def __call__(self, logits):
-        return logits.argmax(dim=-1)
+        return logits.argmax()


 class NextTokenChooser:
@@ -87,8 +85,9 @@ class NextTokenChooser:
         logprobs = torch.log_softmax(scores, -1)

         # Choose tokens
-        next_ids = self.choice(scores)
-        return next_ids, logprobs
+        next_id = self.choice(scores[-1])
+
+        return next_id.view(1, 1), logprobs

     @classmethod
     def from_pb(
@@ -163,6 +162,7 @@ def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
+
         # Set the device id.
         assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
         device = rank % torch.cuda.device_count()
@@ -181,7 +181,7 @@ def initialize_torch_distributed():
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
-        pg_options=options
+        pg_options=options,
     )
     return torch.distributed.group.WORLD, rank, world_size
......
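For intuition on the NextTokenChooser change: the choice now runs on the scores of the last position only and the result is reshaped to `(1, 1)`, which `causal_lm.py` concatenates directly and `seq2seq_lm.py` squeezes first. A small shape check with illustrative sizes:

    import torch

    scores = torch.randn(3, 100)              # (sequence_length, vocab_size), illustrative sizes
    logprobs = torch.log_softmax(scores, -1)  # logprobs over every position, as in __call__

    next_id = scores[-1].argmax()             # Greedy: argmax over the last position's scores
    next_id = next_id.view(1, 1)              # reshape to (1, 1) for the models to concatenate
    print(next_id.shape)                      # torch.Size([1, 1])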