Commit d46aa964 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'api_change' into 'main'

API improvements.

See merge request ADLR/megatron-lm!337
parents b31e1296 0694205c
......@@ -39,10 +39,19 @@ class MegatronGenerate(Resource):
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("current time: ", datetime.datetime.now())
if not "prompts" in request.get_json():
return "prompts argument required", 400
sentences = request.get_json()["sentences"]
if len(sentences) > 128:
return "Maximum number of sentences is 128", 400
if "max_len" in request.get_json():
return "max_len is no longer used. Replace with tokens_to_generate", 400
if "sentences" in request.get_json():
return "sentences is no longer used. Replace with prompts", 400
prompts = request.get_json()["prompts"]
if len(prompts) > 128:
return "Maximum number of prompts is 128", 400
tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow
if "tokens_to_generate" in request.get_json():
......@@ -52,11 +61,11 @@ class MegatronGenerate(Resource):
if tokens_to_generate < 1:
return "tokens_to_generate must be an integer greater than 0"
all_probs = False
if "all_probs" in request.get_json():
all_probs = request.get_json()["all_probs"]
if not isinstance(all_probs, bool):
return "all_probs must be a boolean value"
logprobs = False
if "logprobs" in request.get_json():
logprobs = request.get_json()["logprobs"]
if not isinstance(logprobs, bool):
return "logprobs must be a boolean value"
temperature = args.temperature
if "temperature" in request.get_json():
......@@ -66,6 +75,22 @@ class MegatronGenerate(Resource):
if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = args.top_k
if "top_k" in request.get_json():
top_k = request.get_json()["top_k"]
if not (type(top_k) == int):
return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
if not (0 < top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = args.top_p
if "top_p" in request.get_json():
top_p = request.get_json()["top_p"]
if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0"
if not (0 < top_p <= 1.0):
return "top_p must be less than or equal to 1.0"
add_BOS = False
if "add_BOS" in request.get_json():
add_BOS = request.get_json()["add_BOS"]
......@@ -74,24 +99,24 @@ class MegatronGenerate(Resource):
with lock: # Need to get lock to keep multiple threads from hitting code
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
if all_probs:
return jsonify({"sentences": resp_sentences,
"segments": resp_sentences_seg,
"logits": output_logits,
"all_logits": full_logits,
"tokens": tokens})
response, response_seg, response_logprobs = generate(self.model,
prompts,
tokens_to_generate,
logprobs,
temperature,
top_k,
top_p,
add_BOS)
return jsonify({"sentences": resp_sentences,
"segments": resp_sentences_seg,
"logits": output_logits})
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__, static_url_path='')
api = Api(self.app)
api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
def run(self, url):
self.app.run(url, threaded=True, debug=False)
......@@ -108,12 +108,12 @@ def tokenize_batch(sentences, max_len, add_BOS):
context_length_tensor = torch.cuda.LongTensor(context_lengths)
return context_tokens_tensor, context_length_tensor
def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
"""
Needs to be synced up with receive_generate_info
"""
# Send the sizes of the tensors
input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, logprobs, temperature, top_k, top_p]
input_info_tensor = torch.cuda.FloatTensor(input_info)
torch.distributed.broadcast(input_info_tensor, 0)
......@@ -125,13 +125,15 @@ def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
input_info_tensor = torch.empty(7, dtype=torch.float32, device=torch.cuda.current_device())
torch.distributed.broadcast(input_info_tensor, 0)
batch_size = int(input_info_tensor[0].item())
seq_len = int(input_info_tensor[1].item())
tokens_to_generate = int(input_info_tensor[2].item())
all_probs = int(input_info_tensor[3].item())
logprobs = bool(input_info_tensor[3].item())
temperature = float(input_info_tensor[4].item())
top_k = int(input_info_tensor[5].item())
top_p = float(input_info_tensor[6].item())
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
......@@ -140,56 +142,53 @@ def receive_generate_info():
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
return context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p
def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
batch_token_iterator = sample_sequence_batch(model,
context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids,
attention_mask,
position_ids,
tokens_to_generate,
all_probs,
temperature=temperature)
for tokens, lengths, output_logits, full_logits in batch_token_iterator:
logprobs,
temperature,
top_k,
top_p)
for tokens, lengths, output_logits in batch_token_iterator:
context_length += 1
if mpu.is_pipeline_last_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(output_logits, src, group)
if all_probs:
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(full_logits, src, group)
else:
if mpu.is_pipeline_first_stage():
if logprobs:
if mpu.is_pipeline_last_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
torch.distributed.broadcast(output_logits, src, group)
if all_probs:
args = get_args()
else:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
torch.distributed.broadcast(full_logits, src, group)
output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
torch.distributed.broadcast(output_logits, src, group)
if tokens is not None:
return tokens[:, :context_length], output_logits, full_logits
return tokens[:, :context_length], output_logits
def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
model.eval()
if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
else:
context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature = receive_generate_info()
context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()
output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
if output is not None:
decode_tokens, output_logits, full_logits = output
decode_tokens, output_logits = output
args = get_args()
tokenizer = get_tokenizer()
......@@ -197,7 +196,8 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
resp_sentences_seg = []
decode_tokens = decode_tokens.cpu().numpy().tolist()
for decode_token in decode_tokens:
for i, decode_token in enumerate(decode_tokens):
resp_sentences.append(tokenizer.detokenize(decode_token))
words = []
for token in decode_token:
......@@ -205,12 +205,10 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
words.append(word)
resp_sentences_seg.append(words)
output_logits = output_logits.cpu().numpy().tolist()
if all_probs:
full_logits = full_logits.cpu().numpy().tolist()
return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
if logprobs:
output_logits = output_logits.cpu().numpy().tolist()
return resp_sentences, resp_sentences_seg, output_logits
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
"""
......@@ -260,9 +258,17 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
return output_tensor
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
def sample_sequence_batch(model,
context_tokens,
context_lengths,
attention_mask,
position_ids,
tokens_to_generate,
logprobs,
temperature,
top_k,
top_p,
type_ids=None):
args = get_args()
tokenizer = get_tokenizer()
......@@ -330,8 +336,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
else:
logits = logits.float()
logits /= temperature
logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
logits = top_k_logits(logits, top_k=top_k,
top_p=top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
......@@ -343,22 +349,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
if output_logits is None:
output_context = F.log_softmax(output[:, :context_length, :], 2)
indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
output_logits = torch.gather(output_context, 2, indices).squeeze(2)
if all_probs:
full_logits = output_context
else:
output_context = F.log_softmax(output, 2)
indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits = torch.cat([output_logits, new_output_logits],1)
if all_probs:
full_logits = torch.cat([full_logits, output_context], 1)
if logprobs:
if output_logits is None:
output_context = F.log_softmax(output[:, :context_length, :], 2)
indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
output_logits = torch.gather(output_context, 2, indices).squeeze(2)
else:
output_context = F.log_softmax(output, 2)
indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits = torch.cat([output_logits, new_output_logits],1)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
......@@ -373,10 +376,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
if all_probs:
yield tokens, lengths, output_logits, full_logits
else:
yield tokens, lengths, output_logits, None
yield tokens, lengths, output_logits
else:
if mpu.is_pipeline_first_stage():
......@@ -385,9 +385,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None, None, None
yield tokens, None, None
else:
yield None, None, None, None
yield None, None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
......
......@@ -25,10 +25,10 @@ if __name__ == "__main__":
url = sys.argv[1]
while True:
sentence = raw_input("Enter prompt: ")
max_len = int(input("Enter number tokens output: "))
data = json.dumps({"sentences": [sentence], "max_len":max_len})
tokens_to_generate = int(input("Enter number of tokens to generate: "))
data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
req = PutRequest(url, data, {'Content-Type': 'application/json'})
response = urllib2.urlopen(req)
resp_sentences = json.load(response)
print("Megatron Response: ")
print(resp_sentences["sentences"][0])
print(resp_sentences["text"][0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment