Commit d5fe59fe authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'stop_tokens' into 'main'

Adding several things to the text_generation_server that were necessary for the demos

See merge request ADLR/megatron-lm!350
parents 10c6ad06 1d391bba
<!-- coding=utf-8-->
<!-- Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.-->
<!---->
<!-- Licensed under the Apache License, Version 2.0 (the "License");-->
<!-- you may not use this file except in compliance with the License.-->
<!-- You may obtain a copy of the License at-->
<!---->
<!-- http://www.apache.org/licenses/LICENSE-2.0-->
<!---->
<!-- Unless required by applicable law or agreed to in writing, software-->
<!-- distributed under the License is distributed on an "AS IS" BASIS,-->
<!-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->
<!-- See the License for the specific language governing permissions and-->
<!-- limitations under the License.-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Megatron</title>
<style>
.wrapper {
max-width: 75%;
margin: auto;
}
h1 {
margin: 3rem 0 1rem 0;
padding: 0;
font-size: 1.5rem;
}
textarea {
width: 100%;
min-height: 300px;
resize: none;
border-radius: 8px;
border: 1px solid #ddd;
padding: 0.5rem;
box-shadow: inset 0 0 0.25rem #ddd;
&:focus {
outline: none;
border: 1px solid darken(#ddd, 5%);
box-shadow: inset 0 0 0.5rem darken(#ddd, 5%);
}
}
#the-count {
float: right;
padding: 0.1rem 0 0 0;
font-size: 0.875rem;
}
/* Chat containers */
.container {
font-family: 'Arial', sans-serif;
font-size: 16px;
border: 2px solid #dedede;
background-color: #f1f1f1;
border-radius: 5px;
padding: 15px;
margin: 10px 0;
}
/* Clear floats */
.container::after {
content: "";
clear: both;
display: table;
}
/* Style images */
.container img {
float: left;
max-width: 60px;
width: 100%;
margin-right: 20px;
border-radius: 50%;
}
</style>
</head>
<body>
<div class="wrapper">
<h1>Prompt Megatron</h1>
<textarea name="prompt" id="prompt" maxlength="1024" placeholder="Add prompt"autofocus></textarea>
<label for="tokens_to_generate">Number tokens to generate (1-1024):</label>
<input type="number" id="tokens_to_generate" name="tokens_to_generate" min="10" max="256", value=32>
<button onclick="submit_query()">Submit</button>
<div id="the-count">
<span id="current">0</span>
<span id="maximum">/ 1000</span>
</div>
<textarea name="response" id="response" maxlength="2048" placeholder="Megatron response..."></textarea>
</div>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script type="text/javascript">
function submit_query() {
$("#response").val("Waiting for Megatron response...");
$.ajax({
url:"api",
type:"PUT",
data:JSON.stringify({prompts: [$("#prompt").val()], tokens_to_generate: parseInt($("#tokens_to_generate").val(),10)}),
contentType:"application/json; charset=utf-8",
dataType:"json",
success: function(data){
data.max_len=35;
$("#response").val(data.text);
}
});
}
$('textarea').keyup(function() {
var characterCount = $(this).val().length,
current = $('#current'),
maximum = $('#maximum'),
theCount = $('#the-count');
current.text(characterCount);
if (characterCount >= 800) {
maximum.css('color', '#8f0001');
current.css('color', '#8f0001');
theCount.css('font-weight','bold');
} else {
maximum.css('color','#666');
theCount.css('font-weight','normal');
}
});
</script>
</body>
</html>
......@@ -35,7 +35,10 @@ def generate_and_post_process(model,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
random_seed=-1):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
......@@ -49,7 +52,10 @@ def generate_and_post_process(model,
top_p_sampling=top_p_sampling,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
......@@ -74,7 +80,10 @@ def generate(model,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
random_seed=-1):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
......@@ -87,8 +96,11 @@ def generate(model,
values = [tokens_to_generate,
return_output_log_probs,
top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination]
values_float_tensor = broadcast_float_list(7, float_list=values)
temperature, add_BOS, use_eod_token_for_early_termination,
stop_on_double_eol,
stop_on_eol,
random_seed]
values_float_tensor = broadcast_float_list(10, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item())
......@@ -96,6 +108,12 @@ def generate(model,
temperature = values_float_tensor[4].item()
add_BOS = bool(values_float_tensor[5].item())
use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
stop_on_double_eol = bool(values_float_tensor[7].item())
stop_on_eol = bool(values_float_tensor[8].item())
random_seed = int(values_float_tensor[9].item())
if random_seed != -1:
torch.random.manual_seed(random_seed)
# Tokenize prompts and get the batch.
# Note that these tensors are broadcaseted to all ranks.
......@@ -117,4 +135,6 @@ def generate(model,
top_k=top_k_sampling,
top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol)
......@@ -96,7 +96,10 @@ def generate_tokens_probs_and_return_on_first_stage(
return_output_log_probs=False,
top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True):
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False
):
"""Main token generation function.
Arguments:
model: no interleaving is supported.
......@@ -131,6 +134,10 @@ def generate_tokens_probs_and_return_on_first_stage(
max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if min_prompt_length >= max_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
......@@ -227,8 +234,20 @@ def generate_tokens_probs_and_return_on_first_stage(
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
# TODO(rprenger) These stopping methods are tokenizer dependent
# instead tokenization should be in the inference loop so stop sequences can be used
if stop_on_double_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
done_token = hit_double_eol | hit_two_eols
elif stop_on_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_eol = (new_sample == 198).byte() & started.byte()
done_token = hit_double_eol | hit_eol
else:
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
......
......@@ -36,9 +36,6 @@ class MegatronGenerate(Resource):
def put(self):
args = get_args()
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("current time: ", datetime.datetime.now())
if not "prompts" in request.get_json():
return "prompts argument required", 400
......@@ -102,8 +99,42 @@ class MegatronGenerate(Resource):
if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value"
if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
return "Empty prompts require add_BOS=true"
stop_on_double_eol = False
if "stop_on_double_eol" in request.get_json():
stop_on_double_eol = request.get_json()["stop_on_double_eol"]
if not isinstance(stop_on_double_eol, bool):
return "stop_on_double_eol must be a boolean value"
stop_on_eol = False
if "stop_on_eol" in request.get_json():
stop_on_eol = request.get_json()["stop_on_eol"]
if not isinstance(stop_on_eol, bool):
return "stop_on_eol must be a boolean value"
random_seed = -1
if "random_seed" in request.get_json():
random_seed = request.get_json()["random_seed"]
if not isinstance(random_seed, int):
return "random_seed must be integer"
if random_seed < 0:
return "random_seed must be a positive integer"
no_log = False
if "no_log" in request.get_json():
no_log = request.get_json()["no_log"]
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
try:
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
......@@ -114,7 +145,13 @@ class MegatronGenerate(Resource):
top_p_sampling=top_p,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=True)
use_eod_token_for_early_termination=True,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response,
"segments": response_seg,
......
......@@ -78,4 +78,7 @@ if __name__ == "__main__":
choice = torch.cuda.LongTensor(1)
torch.distributed.broadcast(choice, 0)
if choice[0].item() == 0:
try:
generate_and_post_process(model)
except ValueError as ve:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment