OpenDAS / Megatron-LM

Commit a5bfc296
Authored Oct 10, 2021 by mshoeybi

    added new inference to the server

Parent: ed6d28b1
Showing 4 changed files with 18 additions and 413 deletions (+18, -413)
  megatron/inference/api.py              +1    -1
  megatron/text_generation_server.py     +15   -9
  megatron/text_generation_utils.py      +0    -401
  tools/run_text_generation_server.py    +2    -2
megatron/inference/api.py  (view file @ a5bfc296)

@@ -64,7 +64,7 @@ def generate_and_post_process(model,
     if return_output_log_probs:
         output_log_probs = output_log_probs.cpu().numpy().tolist()
     if return_all_log_probs:
-        all_log_probs = all_log_probs.cpu().numpy() # .tolist()
+        all_log_probs = all_log_probs.cpu().numpy().tolist()

     return prompts_plus_generations, prompts_plus_generations_segments, \
         output_log_probs, all_log_probs, tokens
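A side note on the one-line fix above (not part of the commit itself): the trailing .tolist() matters because values like these are typically JSON-serialized further up the stack by the Flask server, and a NumPy array cannot be handed to json.dumps while a nested Python list can. A minimal sketch, with an illustrative tensor standing in for all_log_probs:

import json
import torch

# Illustrative stand-in for the all_log_probs tensor in the hunk above.
all_log_probs = torch.randn(2, 5)

# json.dumps(all_log_probs.cpu().numpy()) would raise a TypeError
# (ndarray is not JSON serializable); nested Python lists serialize fine.
payload = all_log_probs.cpu().numpy().tolist()
print(json.dumps({"all_logprobs": payload}))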
megatron/text_generation_server.py  (view file @ a5bfc296)

@@ -20,7 +20,8 @@ from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron import mpu
-from megatron.text_generation_utils import generate
+from megatron.inference.api import generate_and_post_process

 GENERATE_NUM = 0
 lock = threading.Lock()

@@ -99,14 +100,19 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs = generate(
-                self.model,
-                prompts,
-                tokens_to_generate,
-                logprobs,
-                temperature,
-                top_k,
-                top_p,
-                add_BOS)
+            response, response_seg, response_logprobs, _, _ = \
+                generate_and_post_process(
+                self.model,
+                prompts=prompts,
+                tokens_to_generate=tokens_to_generate,
+                return_output_log_probs=logprobs,
+                return_all_log_probs=False,
+                greedy_sampling=args.greedy,
+                top_k_sampling=top_k,
+                top_p_sampling=top_p,
+                temperature=temperature,
+                add_BOS=add_BOS,
+                use_eod_token_for_early_termination=True)
             return jsonify({"text": response,
                             "segments": response_seg,
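For context (not shown in this diff): after this change the MegatronGenerate resource builds its JSON response from generate_and_post_process. A hedged sketch of how a client might exercise the endpoint; the URL, port, HTTP method, and request field names here are assumptions for illustration, not taken from this commit:

import requests

# Hypothetical request; adjust host, port, endpoint path, and fields to match
# how tools/run_text_generation_server.py is actually deployed.
resp = requests.put(
    "http://localhost:5000/api",
    json={"prompts": ["Megatron-LM is"], "tokens_to_generate": 32},
)
print(resp.json()["text"])      # generated continuations
print(resp.json()["segments"])  # per-token segments, as returned above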
megatron/text_generation_utils.py  deleted 100644 → 0  (view file @ ed6d28b1)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward

# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.contiguous().cuda()
    # Get the attention mask and postition ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Cconvert to 1D
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def pad_batch(batch, pad_id, max_len):
    context_lengths = []
    max_context_length = max([len(tokens) for tokens in batch])
    for tokens in batch:
        context_length = len(tokens)
        if context_length < max_context_length + max_len:
            tokens.extend([pad_id] * (max_context_length + max_len - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def tokenize_batch(sentences, max_len, add_BOS):
    args = get_args()
    tokenizer = get_tokenizer()
    if add_BOS:
        context_tokens = [[tokenizer.eod] + tokenizer.tokenize(s)
                          for s in sentences]
    else:
        context_tokens = [tokenizer.tokenize(s) for s in sentences]
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, max_len)
    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor


def send_generate_info(context_tokens_tensor, context_length_tensor,
                       tokens_to_generate, logprobs, temperature,
                       top_k, top_p):
    """
    Needs to be synced up with receive_generate_info
    """
    # Send the sizes of the tensors
    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1),
                  tokens_to_generate, logprobs, temperature, top_k, top_p]
    input_info_tensor = torch.cuda.FloatTensor(input_info)
    torch.distributed.broadcast(input_info_tensor, 0)

    # Send variables to all ranks
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)


def receive_generate_info():
    """
    Needs to be synced up with send_generate_info
    """
    input_info_tensor = torch.empty(7, dtype=torch.float32,
                                    device=torch.cuda.current_device())
    torch.distributed.broadcast(input_info_tensor, 0)
    batch_size = int(input_info_tensor[0].item())
    seq_len = int(input_info_tensor[1].item())
    tokens_to_generate = int(input_info_tensor[2].item())
    logprobs = bool(input_info_tensor[3].item())
    temperature = float(input_info_tensor[4].item())
    top_k = int(input_info_tensor[5].item())
    top_p = float(input_info_tensor[6].item())

    context_length_tensor = torch.empty(batch_size, dtype=torch.int64,
                                        device=torch.cuda.current_device())
    context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64,
                                        device=torch.cuda.current_device())

    # Send variables to all ranks
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)

    return context_length_tensor, context_tokens_tensor, tokens_to_generate, \
        logprobs, temperature, top_k, top_p


def synced_generate(model, context_tokens_tensor, context_length_tensor,
                    tokens_to_generate, logprobs, temperature, top_k, top_p):
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model,
                                                 context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask,
                                                 position_ids,
                                                 tokens_to_generate,
                                                 logprobs,
                                                 temperature,
                                                 top_k,
                                                 top_p)
    for tokens, lengths, output_logits in batch_token_iterator:
        context_length += 1

    if logprobs:
        if mpu.is_pipeline_last_stage():
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            torch.distributed.broadcast(output_logits, src, group)
        else:
            if mpu.is_pipeline_first_stage():
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                output_logits = torch.empty(tokens.size(0),
                                            context_length - 1,
                                            dtype=torch.float32,
                                            device=torch.device("cuda"))
                torch.distributed.broadcast(output_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits


def generate(model, sentences=None, tokens_to_generate=0, logprobs=False,
             temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(
            sentences, tokens_to_generate, add_BOS)
        send_generate_info(context_tokens_tensor, context_length_tensor,
                           tokens_to_generate, logprobs, temperature,
                           top_k, top_p)
    else:
        context_length_tensor, context_tokens_tensor, tokens_to_generate, \
            logprobs, temperature, top_k, top_p = receive_generate_info()

    output = synced_generate(model, context_tokens_tensor,
                             context_length_tensor, tokens_to_generate,
                             logprobs, temperature, top_k, top_p)
    if output is not None:
        decode_tokens, output_logits = output

        args = get_args()
        tokenizer = get_tokenizer()
        resp_sentences = []
        resp_sentences_seg = []
        decode_tokens = decode_tokens.cpu().numpy().tolist()
        for i, decode_token in enumerate(decode_tokens):
            resp_sentences.append(tokenizer.detokenize(decode_token))
            words = []
            for token in decode_token:
                word = tokenizer.tokenizer.decoder[token]
                word = bytearray(
                    [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
                        'utf-8', errors='replace')
                words.append(word)
            resp_sentences_seg.append(words)

        if logprobs:
            output_logits = output_logits.cpu().numpy().tolist()
        return resp_sentences, resp_sentences_seg, output_logits


def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """
    This function is here to provide an a matching API for a legacy task
    This implementation hasn't been tested yet to make sure it matches
    """
    #assert False, "Implementation untested"
    args = get_args()
    args.eos_id = eos_token_id
    raw_text_len = len(context)
    resp_sentences = generate(model, [context], max_gen_length)
    if resp_sentences:
        return resp_sentences[0][raw_text_len:]


def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 set_inference_key_value_memory=False,
                 inference_max_sequence_len=None):

    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    args.micro_batch_size = tokens.shape[0]

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          set_inference_key_value_memory=set_inference_key_value_memory,
                          inference_max_sequence_len=inference_max_sequence_len)

    send_forward(output_tensor)

    args.seq_length = orig_seq_length

    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids, tokens_to_generate,
                          logprobs, temperature, top_k, top_p,
                          type_ids=None):
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id id found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > args.seq_length:
            maxlen = args.seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                if type_ids is not None:
                    types2use = type_ids[:, :context_length]
                attention_mask2use = attention_mask[
                    ..., :context_length, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(
                    batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
                        batch_size, -1)
                attention_mask2use = attention_mask[
                    ..., (context_length - 1):context_length, :context_length]

            output = forward_step(
                model, tokens2use, positions2use, attention_mask2use,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=maxlen,
                tokentype_ids=types2use)

            if mpu.is_pipeline_last_stage():
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                # Clamp the out of vocabulary tokens.
                tokenizer = get_tokenizer()
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens

                if logprobs:
                    if output_logits is None:
                        output_context = F.log_softmax(
                            output[:, :context_length, :], 2)
                        indices = torch.unsqueeze(
                            tokens[:, 1:context_length + 1], 2)
                        output_logits = torch.gather(
                            output_context, 2, indices).squeeze(2)
                    else:
                        output_context = F.log_softmax(output, 2)
                        indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                        new_output_logits = torch.gather(
                            output_context, 2, indices).squeeze(2)

                        # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
                        output_logits = torch.cat(
                            [output_logits, new_output_logits], 1)

                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                yield tokens, lengths, output_logits

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None
                else:
                    yield None, None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
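The module above is removed by this commit; the server now reaches generation through megatron/inference/api.py instead. As a standalone illustration of the top-k / nucleus (top-p) filtering that the deleted top_k_logits implemented, here is a minimal, self-contained sketch; the function and variable names are generic and are not the API of the new inference module:

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Top-k / nucleus (top-p) filtering over a [batch, vocab] logits tensor."""
    logits = logits.clone()  # do not modify the caller's tensor
    if top_k > 0:
        # Mask everything below the k-th largest logit in each row.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        for i in range(sorted_indices.size(0)):
            logits[i][sorted_indices[i][remove[i]]] = filter_value
    return logits

# Example: sample one token per row after filtering.
logits = torch.randn(2, 10)
probs = F.softmax(filter_logits(logits, top_k=3, top_p=0.9), dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)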
tools/run_text_generation_server.py  (view file @ a5bfc296)

@@ -27,7 +27,7 @@ from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_server import MegatronServer
-from megatron.text_generation_utils import generate
+from megatron.inference.api import generate_and_post_process
 import torch

 def model_provider(pre_process=True, post_process=True):

@@ -80,4 +80,4 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate(model)
+            generate_and_post_process(model)
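For context, the loop above implements a simple control protocol: rank 0 (which also runs the HTTP server) broadcasts an operation code, and every other rank blocks on that broadcast and then joins the same collective call. A hedged, simplified sketch of that pattern; the loop structure and names are illustrative, see the actual file for the full version:

import torch
from megatron.inference.api import generate_and_post_process

def worker_loop(model):
    # Runs on every rank other than 0. Rank 0 decides what to do and
    # broadcasts an op code; 0 means "run a generate step".
    while True:
        choice = torch.cuda.LongTensor(1)        # buffer for the op code
        torch.distributed.broadcast(choice, 0)   # blocks until rank 0 sends it
        if choice[0].item() == 0:
            # All ranks must enter this collective together; rank 0 enters it
            # from inside the Flask request handler shown earlier.
            generate_and_post_process(model)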