OpenDAS / text-generation-inference / Commits

Commit 9af45414 (unverified)
feat: add distributed tracing (#62)

Authored Feb 13, 2023 by OlivierDehaene; committed via GitHub on Feb 13, 2023
Parent: e520d5b3

Changes: 31. Showing 11 changed files with 596 additions and 64 deletions (+596, -64).
router/src/validation.rs                        +18   -5
rust-toolchain.toml                              +1   -1
server/Makefile                                  +1   -1
server/poetry.lock                             +461  -33
server/pyproject.toml                            +4   -1
server/text_generation/cli.py                   +17  -11
server/text_generation/models/causal_lm.py       +8   -2
server/text_generation/models/seq2seq_lm.py      +6   -1
server/text_generation/server.py                 +7   -1
server/text_generation/tracing.py               +65   -0
server/text_generation/utils.py                  +8   -8
router/src/validation.rs

@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};
+use tracing::{instrument, Span};

 const MAX_MAX_NEW_TOKENS: u32 = 512;
 const MAX_STOP_SEQUENCES: usize = 4;

@@ -36,6 +37,7 @@ impl Validation {
     }

     /// Validate a payload and get the number of tokens in the input
+    #[instrument(skip_all)]
     pub(crate) async fn validate(
         &self,
         request: GenerateRequest,

@@ -44,7 +46,10 @@ impl Validation {
         let (sender, receiver) = oneshot::channel();
         // Send request to the background validation task
         // Unwrap is safe here
-        self.sender.send((request, sender)).await.unwrap();
+        self.sender
+            .send((request, sender, Span::current()))
+            .await
+            .unwrap();
         // Await on response channel
         // Unwrap is safe here
         receiver.await.unwrap()

@@ -97,10 +102,17 @@ fn validation_worker(
     let mut rng = rand::thread_rng();

     // Loop over requests
-    while let Some((request, response_tx)) = receiver.blocking_recv() {
-        response_tx
-            .send(validate(request, &tokenizer, max_input_length, &mut rng))
-            .unwrap_or(())
+    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
+        parent_span.in_scope(|| {
+            response_tx
+                .send(
+                    validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
+                        tracing::error!("{err}");
+                        err
+                    }),
+                )
+                .unwrap_or(())
+        })
     }
 }

@@ -203,6 +215,7 @@ fn validate(
 type ValidationRequest = (
     GenerateRequest,
     oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
+    Span,
 );

 #[derive(Debug)]
rust-toolchain.toml

 [toolchain]
-channel = "1.65.0"
+channel = "1.67.0"
 components = ["rustfmt", "clippy"]
\ No newline at end of file
server/Makefile

 gen-server:
 	# Compile protos
-	pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.51.1 --no-cache-dir
 	mkdir text_generation/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
 	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
server/poetry.lock

(This diff is collapsed and not shown here.)
server/pyproject.toml

@@ -19,12 +19,15 @@ accelerate = "^0.15.0"
 bitsandbytes = "^0.35.1"
 safetensors = "^0.2.4"
 loguru = "^0.6.0"
+opentelemetry-api = "^1.15.0"
+opentelemetry-exporter-otlp = "^1.15.0"
+opentelemetry-instrumentation-grpc = "^0.36b0"

 [tool.poetry.extras]
 bnb = ["bitsandbytes"]

 [tool.poetry.group.dev.dependencies]
-grpcio-tools = "^1.49.1"
+grpcio-tools = "^1.51.1"
 pytest = "^7.2.0"

 [build-system]
server/text_generation/cli.py

@@ -7,6 +7,7 @@ from loguru import logger
 from typing import Optional

 from text_generation import server, utils
+from text_generation.tracing import setup_tracing

 app = typer.Typer()

@@ -20,18 +21,8 @@ def serve(
     uds_path: Path = "/tmp/text-generation",
     logger_level: str = "INFO",
     json_output: bool = False,
+    otlp_endpoint: Optional[str] = None,
 ):
-    # Remove default handler
-    logger.remove()
-    logger.add(
-        sys.stdout,
-        format="{message}",
-        filter="text_generation",
-        level=logger_level,
-        serialize=json_output,
-        backtrace=True,
-        diagnose=False,
-    )
     if sharded:
         assert (
             os.getenv("RANK", None) is not None

@@ -46,6 +37,21 @@ def serve(
             os.getenv("MASTER_PORT", None) is not None
         ), "MASTER_PORT must be set when sharded is True"

+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
+
+    # Setup OpenTelemetry distributed tracing
+    if otlp_endpoint is not None:
+        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
+
     server.serve(model_id, revision, sharded, quantize, uds_path)
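The serve command therefore only initialises tracing when an OTLP endpoint is supplied. As a rough illustration of what that amounts to at runtime (a sketch, not part of the commit; the OTLP_ENDPOINT environment variable below is a hypothetical stand-in for the new CLI argument):

import os

from text_generation.tracing import setup_tracing

# Placeholder for the value that would be passed via the new otlp_endpoint CLI argument.
otlp_endpoint = os.getenv("OTLP_ENDPOINT")

if otlp_endpoint is not None:
    # RANK is set by the launcher when the model is sharded; each shard then
    # reports itself as the service "text-generation-inference.server-<rank>".
    setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)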
server/text_generation/models/causal_lm.py

 import torch

 from dataclasses import dataclass
+from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

@@ -9,6 +10,8 @@ from text_generation.models.types import Batch, PrefillTokens, Generation, Gener
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

+tracer = trace.get_tracer(__name__)
+

 @dataclass
 class CausalLMBatch(Batch):

@@ -94,6 +97,7 @@ class CausalLMBatch(Batch):
         )

     @classmethod
+    @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = sum(batch.size for batch in batches)

@@ -286,6 +290,7 @@ class CausalLM(Model):
         )
         return outputs.logits, outputs.past_key_values

+    @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:

@@ -331,8 +336,9 @@ class CausalLM(Model):
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
-            next_token_id = tokens[-1].view(1, 1)
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits
+            )

             # Append next token to all tokens
             all_input_ids = torch.cat([all_input_ids, next_token_id])
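The @tracer.start_as_current_span(...) decorators added here (and mirrored in seq2seq_lm.py below) are plain OpenTelemetry API usage: the object returned by start_as_current_span is a context manager that can also wrap a function, so each call to concatenate or generate_token is recorded as a span nested under whatever span is currently active, which in this server should be the gRPC span opened by the interceptor added in server.py. A minimal, self-contained sketch of the same pattern; the function name and body are illustrative only, and spans stay no-ops until a tracer provider is configured (see setup_tracing in tracing.py further down):

import time

from opentelemetry import trace

tracer = trace.get_tracer(__name__)


# Same decorator form as the diff uses on concatenate / generate_token.
@tracer.start_as_current_span("concatenate")
def concatenate_batches(batches):
    # Everything executed here is timed inside the "concatenate" span.
    time.sleep(0.01)
    return batches


# Without a configured TracerProvider this runs with no-op spans and just returns.
concatenate_batches([])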
server/text_generation/models/seq2seq_lm.py

 import torch

 from dataclasses import dataclass
+from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

@@ -9,6 +10,8 @@ from text_generation.models.types import GeneratedText, Batch, Generation, Prefi
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

+tracer = trace.get_tracer(__name__)
+

 @dataclass
 class Seq2SeqLMBatch(Batch):

@@ -107,6 +110,7 @@ class Seq2SeqLMBatch(Batch):
         )

     @classmethod
+    @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
         """Concatenate multiple batches together by padding internal torch tensors"""

@@ -361,6 +365,7 @@ class Seq2SeqLM(Model):
             outputs.past_key_values,
         )

+    @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:

@@ -418,7 +423,7 @@ class Seq2SeqLM(Model):
             )

             # Append next token to decoder tokens
-            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
             new_decoder_input_length = decoder_input_length + 1

             # Generated token
server/text_generation/server.py

@@ -13,6 +13,7 @@ from text_generation.cache import Cache
 from text_generation.interceptor import ExceptionInterceptor
 from text_generation.models import Model, get_model
 from text_generation.pb import generate_pb2_grpc, generate_pb2
+from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

@@ -100,7 +101,12 @@ def serve(
             logger.exception("Error when initializing model")
             raise

-        server = aio.server(interceptors=[ExceptionInterceptor()])
+        server = aio.server(
+            interceptors=[
+                ExceptionInterceptor(),
+                UDSOpenTelemetryAioServerInterceptor(),
+            ]
+        )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
             TextGenerationService(model, Cache(), server_urls), server
         )
server/text_generation/tracing.py  (new file, mode 100644)

import grpc
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.grpc._aio_server import (
    OpenTelemetryAioServerInterceptor,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
    BatchSpanProcessor,
)


class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
    def __init__(self):
        super().__init__(trace.get_tracer(__name__))

    def _start_span(self, handler_call_details, context, set_status_on_exception=False):
        """
        Rewrite _start_span method to support Unix Domain Socket gRPC contexts
        """

        # standard attributes
        attributes = {
            SpanAttributes.RPC_SYSTEM: "grpc",
            SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],
        }

        # if we have details about the call, split into service and method
        if handler_call_details.method:
            service, method = handler_call_details.method.lstrip("/").split("/", 1)

            attributes.update(
                {
                    SpanAttributes.RPC_METHOD: method,
                    SpanAttributes.RPC_SERVICE: service,
                }
            )

        # add some attributes from the metadata
        metadata = dict(context.invocation_metadata())
        if "user-agent" in metadata:
            attributes["rpc.user_agent"] = metadata["user-agent"]

        # We use gRPC over a UNIX socket
        attributes.update({SpanAttributes.NET_TRANSPORT: "unix"})

        return self._tracer.start_as_current_span(
            name=handler_call_details.method,
            kind=trace.SpanKind.SERVER,
            attributes=attributes,
            set_status_on_exception=set_status_on_exception,
        )


def setup_tracing(shard: int, otlp_endpoint: str):
    resource = Resource.create(
        attributes={"service.name": f"text-generation-inference.server-{shard}"}
    )
    span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
    span_processor = BatchSpanProcessor(span_exporter)

    trace.set_tracer_provider(TracerProvider(resource=resource))
    trace.get_tracer_provider().add_span_processor(span_processor)
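As a quick sanity check, this module can be exercised on its own. The sketch below is not part of the commit; the collector address is a placeholder for wherever an OTLP/gRPC endpoint is actually listening:

from opentelemetry import trace

from text_generation.tracing import setup_tracing

# Placeholder endpoint; point this at a real OpenTelemetry collector.
setup_tracing(shard=0, otlp_endpoint="http://localhost:4317")

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("smoke-test") as span:
    span.set_attribute("example.note", "hello from shard 0")

# BatchSpanProcessor exports in the background; flush explicitly if the
# process exits right after emitting the span.
trace.get_tracer_provider().force_flush()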
server/text_generation/utils.py

@@ -36,16 +36,14 @@ class Sampling:
         self.seed = seed

     def __call__(self, logits):
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(
-            probs, num_samples=1, generator=self.generator
-        ).squeeze(1)
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
         return next_tokens


 class Greedy:
     def __call__(self, logits):
-        return logits.argmax(dim=-1)
+        return logits.argmax()


 class NextTokenChooser:

@@ -87,8 +85,9 @@ class NextTokenChooser:
         logprobs = torch.log_softmax(scores, -1)

         # Choose tokens
-        next_ids = self.choice(scores)
-        return next_ids, logprobs
+        next_id = self.choice(scores[-1])
+
+        return next_id.view(1, 1), logprobs

     @classmethod
     def from_pb(

@@ -163,6 +162,7 @@ def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL

         # Set the device id.
         assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
         device = rank % torch.cuda.device_count()

@@ -181,7 +181,7 @@ def initialize_torch_distributed():
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
-        pg_options=options
+        pg_options=options,
     )

     return torch.distributed.group.WORLD, rank, world_size
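Not every change in this file is about tracing: the next-token path now selects from the last position only (scores[-1]) and returns a [1, 1] tensor, which is why the callers in causal_lm.py and seq2seq_lm.py no longer do tokens[-1].view(1, 1) themselves. A small shape sketch (illustrative values only, not code from the commit):

import torch

# Pretend logits for a 3-token sequence over a 5-token vocabulary.
scores = torch.randn(3, 5)

per_position_ids = scores.argmax(dim=-1)  # shape [3]: one id per position (old Greedy)
last_id = scores[-1].argmax()             # 0-d tensor: id for the last position (new Greedy)

# Reshaped to [1, 1] so it concatenates directly onto the running input ids.
next_token_id = last_id.view(1, 1)
all_input_ids = torch.zeros(4, 1, dtype=torch.long)
all_input_ids = torch.cat([all_input_ids, next_token_id])

print(per_position_ids.shape, next_token_id.shape, all_input_ids.shape)
# torch.Size([3]) torch.Size([1, 1]) torch.Size([5, 1])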