OpenDAS / text-generation-inference · Commit fcc2c5fc (unverified)
Authored Jan 05, 2023 by OlivierDehaene, committed via GitHub on Jan 05, 2023
Parent: b94f3021

feat(launcher): Log server stdout (#19)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Changes: 9 changed files, with 512 additions and 309 deletions (+512 −309)

launcher/Cargo.toml                         +1   −1
launcher/src/main.rs                        +22  −0
launcher/tests/integration_tests.rs         +2   −6
server/poetry.lock                          +432 −296
server/pyproject.toml                       +5   −2
server/text_generation/cli.py               +15  −0
server/text_generation/interceptor.py       +29  −0
server/text_generation/models/__init__.py   +1   −1
server/text_generation/server.py            +5   −3
launcher/Cargo.toml

```diff
@@ -8,6 +8,7 @@ description = "Text Generation Launcher"
 [dependencies]
 clap = { version = "4.0.15", features = ["derive", "env"] }
 ctrlc = { version = "3.2.3", features = ["termination"] }
+serde_json = "1.0.89"
 subprocess = "0.2.9"
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.16", features = ["json"] }
@@ -16,4 +17,3 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
 float_eq = "1.0.1"
 reqwest = { version = "0.11.13", features = ["blocking", "json"] }
 serde = "1.0.150"
-serde_json = "1.0.89"
```
launcher/src/main.rs

```diff
@@ -11,6 +11,7 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
+use serde_json::Value;
 use subprocess::{Popen, PopenConfig, PopenError, Redirection};
 
 /// App Configuration
@@ -274,6 +275,9 @@ fn shard_manager(
         model_name,
         "--uds-path".to_string(),
         uds_path,
+        "--logger-level".to_string(),
+        "ERROR".to_string(),
+        "--json-output".to_string(),
     ];
 
     if world_size > 1 {
@@ -347,6 +351,24 @@ fn shard_manager(
         }
     };
 
+    // Redirect STDOUT to the console
+    let shard_stdout = p.stdout.take().unwrap();
+
+    thread::spawn(move || {
+        // Enter shard-manager tracing span
+        let stdout = BufReader::new(shard_stdout);
+        let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
+        for line in stdout.lines() {
+            // Parse loguru logs
+            if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
+                if let Some(text) = value.get("text") {
+                    // Format escaped newlines
+                    tracing::error!("{}", text.to_string().replace("\\n", "\n"));
+                }
+            }
+        }
+    });
+
     let mut ready = false;
     let start_time = Instant::now();
     let mut wait_time = Instant::now();
```
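Aside: the `replace("\\n", "\n")` in the new thread exists because `Value::to_string()` keeps JSON string escapes, so a real newline in the shard's message arrives as the two characters `\n`. A Python illustration of the same round trip (illustration only, not part of this commit):

```python
# Illustration only (not part of this commit): inside a JSON log line, a
# real newline in the message is stored as the two-character escape \n,
# and re-serializing the extracted "text" value (like Value::to_string()
# in Rust) keeps it escaped, so the launcher un-escapes before handing
# the text to tracing.
import json

line = json.dumps({"text": "Traceback (most recent call last):\n  ...\n"})
text = json.loads(line)["text"]

escaped = json.dumps(text)             # one line, newlines shown as \n
print(escaped)
print(escaped.replace("\\n", "\n"))    # readable, multi-line again
```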
launcher/tests/integration_tests.rs

```diff
@@ -41,25 +41,21 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
         &argv,
         PopenConfig {
             stdout: Redirection::Pipe,
-            stderr: Redirection::Pipe,
+            stderr: Redirection::Merge,
             ..Default::default()
         },
     )
     .expect("Could not start launcher");
 
     // Redirect STDOUT and STDERR to the console
+    // (STDERR is merged into STDOUT)
     let launcher_stdout = launcher.stdout.take().unwrap();
-    let launcher_stderr = launcher.stderr.take().unwrap();
 
     thread::spawn(move || {
         let stdout = BufReader::new(launcher_stdout);
-        let stderr = BufReader::new(launcher_stderr);
         for line in stdout.lines() {
             println!("{}", line.unwrap());
         }
-        for line in stderr.lines() {
-            println!("{}", line.unwrap());
-        }
     });
 
     for _ in 0..60 {
```
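`Redirection::Merge` hands the child's stderr the same pipe as its stdout, so the test now needs only one reader loop. For comparison, a sketch of the equivalent pattern with Python's standard subprocess module (illustration only, not part of this commit):

```python
# Comparison sketch (not part of this commit): merging stderr into stdout
# with Python's standard subprocess module, so a single reader loop sees
# both streams, mirroring Redirection::Merge in the Rust test above.
import subprocess

proc = subprocess.Popen(
    ["ls", "definitely-missing-file"],  # placeholder command writing to stderr
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,           # the merge
    text=True,
)
for line in proc.stdout:
    print(line, end="")
proc.wait()
```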
server/poetry.lock

(Lockfile diff collapsed: +432 −296.)
server/pyproject.toml

```diff
@@ -10,12 +10,15 @@ text-generation-server = 'text_generation.cli:app'
 [tool.poetry.dependencies]
 python = "^3.9"
 protobuf = "^4.21.7"
-grpcio = "^1.49.1"
+grpcio = "^1.51.1"
+grpcio-status = "^1.51.1"
+grpcio-reflection = "^1.51.1"
+grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-grpcio-reflection = "^1.49.1"
 accelerate = "^0.12.0"
 bitsandbytes = "^0.35.1"
 safetensors = "^0.2.4"
+loguru = "^0.6.0"
 
 [tool.poetry.extras]
 bnb = ["bitsandbytes"]
```
server/text_generation/cli.py

```diff
 import os
+import sys
 
 import typer
 
 from pathlib import Path
+from loguru import logger
 
 from text_generation import server, utils
@@ -14,7 +16,20 @@ def serve(
     sharded: bool = False,
     quantize: bool = False,
     uds_path: Path = "/tmp/text-generation",
+    logger_level: str = "INFO",
+    json_output: bool = False,
 ):
+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
     if sharded:
         assert (
             os.getenv("RANK", None) is not None
```
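With `--json-output` (which the launcher now passes alongside `--logger-level ERROR`), each record becomes one JSON line on stdout. A minimal sketch of what this handler emits when `json_output=True` (illustration only, relying on loguru's documented `serialize` behavior):

```python
# Sketch (not part of this commit): what the handler above emits when
# json_output=True. With serialize=True, loguru wraps each record in a
# JSON object whose "text" key holds the formatted message; that "text"
# key is what the Rust launcher pulls out with serde_json.
import sys

from loguru import logger

logger.remove()
logger.add(sys.stdout, format="{message}", level="ERROR", serialize=True)

logger.error("Shard failed to start")
# Prints a single JSON line, roughly:
# {"text": "Shard failed to start\n", "record": {"level": {...}, ...}}
```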
server/text_generation/interceptor.py (new file, mode 100644)

```python
import grpc

from google.rpc import status_pb2, code_pb2
from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any


class ExceptionInterceptor(AsyncServerInterceptor):
    async def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        try:
            response = method(request_or_iterator, context)
            return await response
        except Exception as err:
            method_name = method_name.split("/")[-1]
            logger.exception(f"Method {method_name} encountered an error.")

            await context.abort_with_status(
                rpc_status.to_status(
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
                )
            )
```
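The status object built here can be inspected on its own; a short sketch of the same construction and of how the `code_pb2` integer maps onto the `grpc.StatusCode` a client observes (illustration only, requires grpcio-status):

```python
# Illustration only (not part of this commit): the same status object the
# interceptor aborts with, inspected directly. rpc_status.to_status maps
# the code_pb2 integer onto the grpc.StatusCode the client will observe.
from google.rpc import code_pb2, status_pb2
from grpc_status import rpc_status

err = RuntimeError("shard failed to load weights")  # hypothetical error
status = rpc_status.to_status(
    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
)

print(status.code)     # StatusCode.INTERNAL
print(status.details)  # str(err), as the client sees it
```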
server/text_generation/models/__init__.py

```diff
@@ -4,7 +4,7 @@ from text_generation.models.bloom import BLOOM, BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
 
-__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
+__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM", "get_model"]
 
 
 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
```
server/text_generation/server.py

```diff
@@ -2,12 +2,14 @@ import asyncio
 import os
 
 from grpc import aio
+from loguru import logger
 
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
 from typing import List
 
 from text_generation.cache import Cache
+from text_generation.interceptor import ExceptionInterceptor
 from text_generation.models import Model, get_model
 from text_generation.pb import generate_pb2_grpc, generate_pb2
@@ -91,7 +93,7 @@ def serve(
         model = get_model(model_name, sharded, quantize)
 
-        server = aio.server()
+        server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
             TextGenerationService(model, Cache(), server_urls), server
         )
@@ -102,11 +104,11 @@ def serve(
         reflection.enable_server_reflection(SERVICE_NAMES, server)
         server.add_insecure_port(local_url)
         await server.start()
-        print("Server started at {}".format(local_url))
+        logger.info("Server started at {}".format(local_url))
         try:
             await server.wait_for_termination()
         except KeyboardInterrupt:
-            print("Signal received. Shutting down")
+            logger.info("Signal received. Shutting down")
             await server.stop(0)
 
     asyncio.run(serve_inner(model_name, sharded, quantize))
```