OpenDAS / text-generation-inference

Commit 4acc42a6 (unverified)
Authored Feb 07, 2023 by OlivierDehaene · Committed by GitHub on Feb 07, 2023

fix(server): better handling of inference mode (#57)
Parent: e114d874

Showing 6 changed files with 35 additions and 25 deletions:

  launcher/src/main.rs                          +2   -2
  server/text_generation/models/__init__.py     +3   -0
  server/text_generation/models/causal_lm.py    +5  -10
  server/text_generation/models/seq2seq_lm.py   +7  -12
  server/text_generation/server.py              +13  -1
  server/text_generation/utils.py               +5   -0
launcher/src/main.rs

@@ -38,9 +38,9 @@ struct Args {
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server", long, env)]
     shard_uds_path: String,
-    #[clap(default_value = "localhost", long, env)]
+    #[clap(default_value = "0.0.0.0", long, env)]
     master_addr: String,
-    #[clap(default_value = "29500", long, env)]
+    #[clap(default_value = "6000", long, env)]
     master_port: usize,
     #[clap(long, env)]
     json_output: bool,
server/text_generation/models/__init__.py

@@ -28,6 +28,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True
 
+# Disable gradients
+torch.set_grad_enabled(False)
+
 
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
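Note: torch.set_grad_enabled(False) flips autograd off for the whole process, so every forward pass served by these modules runs without building a graph. A minimal sketch of the effect (the toy Linear layer is illustrative, not code from this repo):

import torch

torch.set_grad_enabled(False)  # same global switch added above

layer = torch.nn.Linear(4, 4)      # hypothetical toy module
out = layer(torch.ones(1, 4))      # forward pass records no autograd graph
print(out.requires_grad)           # False, without any per-call no_grad()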
server/text_generation/models/causal_lm.py

@@ -289,17 +289,12 @@ class CausalLM(Model):
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            logits, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.position_ids,
-                batch.past_key_values,
-            )
+        logits, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.position_ids,
+            batch.past_key_values,
+        )
 
         # List of indices to cache
         next_batch_keep_indices = []
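The deleted lines repeated the same device check around every forward pass, roughly this pattern, reconstructed from the removed hunk (run_forward is a hypothetical wrapper, not a repo function):

import torch

def run_forward(model_forward, device, *inputs):
    # Pattern this commit removes: choose the context manager per call,
    # because inference_mode does not work well with the GLOO backend used on CPU.
    context_manager = (
        torch.no_grad if device.type == "cpu" else torch.inference_mode
    )
    with context_manager():
        return model_forward(*inputs)

After this commit, generate_token calls self.forward directly and relies on the global torch.set_grad_enabled(False) plus the inference-mode guard installed in server.py below.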
server/text_generation/models/seq2seq_lm.py

@@ -364,19 +364,14 @@ class Seq2SeqLM(Model):
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            logits, encoder_last_hidden_state, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.decoder_input_ids,
-                batch.decoder_attention_mask,
-                batch.encoder_last_hidden_state,
-                batch.past_key_values,
-            )
+        logits, encoder_last_hidden_state, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.decoder_input_ids,
+            batch.decoder_attention_mask,
+            batch.encoder_last_hidden_state,
+            batch.past_key_values,
+        )
 
         # List of indices to cache
         next_batch_keep_indices = []
server/text_generation/server.py

 import asyncio
 import os
+import torch
 
 from grpc import aio
 from loguru import logger

@@ -19,6 +20,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        if model.device.type == "cuda":
+            # Force inference mode for the lifetime of TextGenerationService
+            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

@@ -89,7 +94,11 @@ def serve(
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]
 
-        model = get_model(model_id, revision, sharded, quantize)
+        try:
+            model = get_model(model_id, revision, sharded, quantize)
+        except Exception:
+            logger.exception("Error when initializing model")
+            raise
 
         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(

@@ -101,8 +110,11 @@ def serve(
         )
         reflection.enable_server_reflection(SERVICE_NAMES, server)
         server.add_insecure_port(local_url)
+
         await server.start()
+
         logger.info("Server started at {}".format(local_url))
+
         try:
             await server.wait_for_termination()
         except KeyboardInterrupt:
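The guard stored on TextGenerationService is the RAII object that also backs torch.inference_mode: inference mode stays enabled for as long as the attribute keeps it alive, so every tensor the CUDA service creates is an inference tensor. A minimal sketch of that lifetime-based pattern (the Holder class is illustrative only; torch._C._InferenceMode is the same private PyTorch API the commit uses):

import torch

class Holder:
    def __init__(self):
        # Same trick as TextGenerationService.__init__ in this commit:
        # inference mode stays on until this attribute is garbage collected.
        self._inference_mode_raii_guard = torch._C._InferenceMode(True)

holder = Holder()
x = torch.ones(2, 2) + 1
print(x.is_inference())  # True while the guard object is alive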
server/text_generation/utils.py

@@ -171,9 +171,14 @@ def initialize_torch_distributed():
     else:
         backend = "gloo"
 
+    master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
+    master_port = os.getenv("MASTER_PORT", "6000")
+    init_method = f"tcp://{master_ip}:{master_port}"
+
     # Call the init process.
     torch.distributed.init_process_group(
         backend=backend,
+        init_method=init_method,
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
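The explicit tcp:// init_method replaces reliance on torch.distributed's default env:// rendezvous, and its fallback values line up with the launcher's new master_addr/master_port defaults (0.0.0.0 and 6000). A single-process sketch of the same rendezvous (world_size=1 and rank=0 are placeholder values so the snippet can run standalone):

import os
from datetime import timedelta

import torch

master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
master_port = os.getenv("MASTER_PORT", "6000")
init_method = f"tcp://{master_ip}:{master_port}"

# GLOO backend over an explicit TCP rendezvous, as in the change above.
torch.distributed.init_process_group(
    backend="gloo",
    init_method=init_method,
    world_size=1,
    rank=0,
    timeout=timedelta(seconds=60),
)
print(torch.distributed.get_world_size())  # 1
torch.distributed.destroy_process_group()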