Commit d46aa964 authored by Jared Casper

Merge branch 'api_change' into 'main'

API improvements.

See merge request ADLR/megatron-lm!337
parents b31e1296 0694205c
@@ -39,10 +39,19 @@ class MegatronGenerate(Resource):
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()),flush=True)
         print("current time: ", datetime.datetime.now())
 
-        sentences = request.get_json()["sentences"]
-        if len(sentences) > 128:
-            return "Maximum number of sentences is 128", 400
+        if not "prompts" in request.get_json():
+            return "prompts argument required", 400
+
+        if "max_len" in request.get_json():
+            return "max_len is no longer used. Replace with tokens_to_generate", 400
+
+        if "sentences" in request.get_json():
+            return "sentences is no longer used. Replace with prompts", 400
+
+        prompts = request.get_json()["prompts"]
+        if len(prompts) > 128:
+            return "Maximum number of prompts is 128", 400
 
         tokens_to_generate = 64  # Choosing hopefully sane default. Full sequence is slow
         if "tokens_to_generate" in request.get_json():
@@ -52,11 +61,11 @@ class MegatronGenerate(Resource):
             if tokens_to_generate < 1:
                 return "tokens_to_generate must be an integer greater than 0"
 
-        all_probs = False
-        if "all_probs" in request.get_json():
-            all_probs = request.get_json()["all_probs"]
-            if not isinstance(all_probs, bool):
-                return "all_probs must be a boolean value"
+        logprobs = False
+        if "logprobs" in request.get_json():
+            logprobs = request.get_json()["logprobs"]
+            if not isinstance(logprobs, bool):
+                return "logprobs must be a boolean value"
 
         temperature = args.temperature
         if "temperature" in request.get_json():
@@ -66,6 +75,22 @@ class MegatronGenerate(Resource):
             if not (0.0 < temperature <= 100.0):
                 return "temperature must be a positive number less than or equal to 100.0"
 
+        top_k = args.top_k
+        if "top_k" in request.get_json():
+            top_k = request.get_json()["top_k"]
+            if not (type(top_k) == int):
+                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
+            if not (0 < top_k <= 1000):
+                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
+
+        top_p = args.top_p
+        if "top_p" in request.get_json():
+            top_p = request.get_json()["top_p"]
+            if not (type(top_p) == float):
+                return "top_p must be a positive float less than or equal to 1.0"
+            if not (0 < top_p <= 1.0):
+                return "top_p must be less than or equal to 1.0"
+
         add_BOS = False
         if "add_BOS" in request.get_json():
             add_BOS = request.get_json()["add_BOS"]
@@ -74,24 +99,24 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
-
-        if all_probs:
-            return jsonify({"sentences": resp_sentences,
-                "segments": resp_sentences_seg,
-                "logits": output_logits,
-                "all_logits": full_logits,
-                "tokens": tokens})
-
-        return jsonify({"sentences": resp_sentences,
-            "segments": resp_sentences_seg,
-            "logits": output_logits})
+            response, response_seg, response_logprobs = generate(self.model,
+                                                                 prompts,
+                                                                 tokens_to_generate,
+                                                                 logprobs,
+                                                                 temperature,
+                                                                 top_k,
+                                                                 top_p,
+                                                                 add_BOS)
+
+        return jsonify({"text": response,
+            "segments": response_seg,
+            "logprobs": response_logprobs})
 
 class MegatronServer(object):
     def __init__(self, model):
         self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
 
     def run(self, url):
         self.app.run(url, threaded=True, debug=False)
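For orientation only (this example is not part of the diff): a client request against the renamed endpoint now sends prompts plus the sampling controls validated above, and reads back text, segments, and logprobs. The use of the requests library and the localhost:5000 address are assumptions for illustration; the actual host comes from how run() is invoked.

# Hypothetical client call showing the new request/response schema.
# Assumes the server is reachable at http://localhost:5000 and that the
# `requests` package is installed (neither is part of this merge request).
import requests

payload = {
    "prompts": ["Megatron-LM is"],   # was "sentences"
    "tokens_to_generate": 32,        # was "max_len"
    "logprobs": True,                # was "all_probs"
    "temperature": 1.0,
    "top_k": 40,                     # int in (0, 1000]; alternatively send "top_p" as a float in (0, 1]
    "add_BOS": False,
}
resp = requests.put("http://localhost:5000/api", json=payload)
result = resp.json()
print(result["text"][0])       # generated text (response key was "sentences")
print(result["segments"][0])   # per-token segments
print(result["logprobs"][0])   # per-token log-probabilities, returned when requested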
@@ -108,12 +108,12 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, logprobs, temperature, top_k, top_p]
     input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
@@ -125,13 +125,15 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(7, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = int(input_info_tensor[0].item())
     seq_len = int(input_info_tensor[1].item())
     tokens_to_generate = int(input_info_tensor[2].item())
-    all_probs = int(input_info_tensor[3].item())
+    logprobs = bool(input_info_tensor[3].item())
     temperature = float(input_info_tensor[4].item())
+    top_k = int(input_info_tensor[5].item())
+    top_p = float(input_info_tensor[6].item())
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
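The two functions above pack every scalar generation setting into one float tensor so a single broadcast from rank 0 carries them all; this change simply grows that tensor from 5 to 7 slots for top_k and top_p. Below is a self-contained sketch of the same pattern using a single-process gloo group instead of Megatron's multi-rank CUDA setup; the process-group bootstrapping is an assumption made only so the example runs on its own.

# Pack-then-broadcast sketch (single process, gloo backend, world size 1).
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# "Sender" side: cast every setting to float so one broadcast carries them all.
tokens_to_generate, logprobs, temperature, top_k, top_p = 64, True, 1.0, 40, 0.9
input_info = torch.tensor([2, 128, tokens_to_generate,
                           logprobs, temperature, top_k, top_p], dtype=torch.float32)
dist.broadcast(input_info, src=0)

# "Receiver" side: unpack each slot and cast it back to its original type.
batch_size = int(input_info[0].item())
seq_len = int(input_info[1].item())
tokens_to_generate = int(input_info[2].item())
logprobs = bool(input_info[3].item())
temperature = float(input_info[4].item())
top_k = int(input_info[5].item())
top_p = float(input_info[6].item())
print(batch_size, seq_len, tokens_to_generate, logprobs, temperature, top_k, top_p)

dist.destroy_process_group()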
@@ -140,56 +142,53 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
 
-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
 
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+    batch_token_iterator = sample_sequence_batch(model,
+                                                 context_tokens_tensor,
                                                  context_length_tensor,
-                                                 attention_mask, position_ids,
+                                                 attention_mask,
+                                                 position_ids,
                                                  tokens_to_generate,
-                                                 all_probs,
-                                                 temperature=temperature)
-
-    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
+                                                 logprobs,
+                                                 temperature,
+                                                 top_k,
+                                                 top_p)
+
+    for tokens, lengths, output_logits in batch_token_iterator:
         context_length += 1
 
-    if mpu.is_pipeline_last_stage():
-        src = mpu.get_pipeline_model_parallel_last_rank()
-        group = mpu.get_embedding_group()
-        torch.distributed.broadcast(output_logits, src, group)
-        if all_probs:
-            src = mpu.get_pipeline_model_parallel_last_rank()
-            group = mpu.get_embedding_group()
-            torch.distributed.broadcast(full_logits, src, group)
-
-    else:
-        if mpu.is_pipeline_first_stage():
-            src = mpu.get_pipeline_model_parallel_last_rank()
-            group = mpu.get_embedding_group()
-            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
-            torch.distributed.broadcast(output_logits, src, group)
-            if all_probs:
-                args = get_args()
-                src = mpu.get_pipeline_model_parallel_last_rank()
-                group = mpu.get_embedding_group()
-                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
-                torch.distributed.broadcast(full_logits, src, group)
+    if logprobs:
+        if mpu.is_pipeline_last_stage():
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            torch.distributed.broadcast(output_logits, src, group)
+        else:
+            if mpu.is_pipeline_first_stage():
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
+                torch.distributed.broadcast(output_logits, src, group)
 
     if tokens is not None:
-        return tokens[:, :context_length], output_logits, full_logits
+        return tokens[:, :context_length], output_logits
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
+def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     else:
-        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     if output is not None:
-        decode_tokens, output_logits, full_logits = output
+        decode_tokens, output_logits = output
 
         args = get_args()
         tokenizer = get_tokenizer()
@@ -197,7 +196,8 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
         resp_sentences_seg = []
 
         decode_tokens = decode_tokens.cpu().numpy().tolist()
-        for decode_token in decode_tokens:
+
+        for i, decode_token in enumerate(decode_tokens):
             resp_sentences.append(tokenizer.detokenize(decode_token))
             words = []
             for token in decode_token:
@@ -205,12 +205,10 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
                 word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
                 words.append(word)
             resp_sentences_seg.append(words)
-        output_logits = output_logits.cpu().numpy().tolist()
-        if all_probs:
-            full_logits = full_logits.cpu().numpy().tolist()
+
+        if logprobs:
+            output_logits = output_logits.cpu().numpy().tolist()
 
-        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
+        return resp_sentences, resp_sentences_seg, output_logits
 
 def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     """
@@ -260,9 +258,17 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     return output_tensor
 
 
-def sample_sequence_batch(model, context_tokens, context_lengths,
-                          attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
+def sample_sequence_batch(model,
+                          context_tokens,
+                          context_lengths,
+                          attention_mask,
+                          position_ids,
+                          tokens_to_generate,
+                          logprobs,
+                          temperature,
+                          top_k,
+                          top_p,
+                          type_ids=None):
 
     args = get_args()
     tokenizer = get_tokenizer()
@@ -330,8 +336,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 else:
                     logits = logits.float()
                     logits /= temperature
-                    logits = top_k_logits(logits, top_k=args.top_k,
-                                          top_p=args.top_p)
+                    logits = top_k_logits(logits, top_k=top_k,
+                                          top_p=top_p)
                     log_probs = F.softmax(logits, dim=-1)
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                 started = context_lengths <= context_length
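Routing top_k and top_p from the request into top_k_logits is what makes the new fields take effect; previously the values came only from the server's command-line args. For reference, a generic sketch of what these filters do (not Megatron's top_k_logits implementation) is below: top-k keeps the k highest logits per row, while top-p keeps the smallest nucleus of tokens whose probabilities sum past p.

# Generic top-k / top-p (nucleus) filtering over a [batch, vocab] logits tensor.
import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    if top_k > 0:
        # Mask out everything below the k-th largest logit in each row.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    if top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative probability exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the top token always survives
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(1, sorted_idx, remove), filter_value)
    return logits

logits = torch.randn(2, 10)                        # stand-in for the model's final logits
filtered = filter_logits(logits / 0.9, top_k=5)    # temperature first, then filtering
probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).view(-1)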
@@ -343,22 +349,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
-                if output_logits is None:
-                    output_context = F.log_softmax(output[:, :context_length, :], 2)
-                    indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
-                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-                    if all_probs:
-                        full_logits = output_context
-                else:
-                    output_context = F.log_softmax(output, 2)
-                    indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
-                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-
-                    # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
-                    output_logits = torch.cat([output_logits, new_output_logits],1)
-                    if all_probs:
-                        full_logits = torch.cat([full_logits, output_context], 1)
+
+                if logprobs:
+                    if output_logits is None:
+                        output_context = F.log_softmax(output[:, :context_length, :], 2)
+                        indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
+                        output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                    else:
+                        output_context = F.log_softmax(output, 2)
+                        indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
+                        new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+
+                        # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
+                        output_logits = torch.cat([output_logits, new_output_logits],1)
 
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()
@@ -373,10 +376,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                if all_probs:
-                    yield tokens, lengths, output_logits, full_logits
-                else:
-                    yield tokens, lengths, output_logits, None
+                yield tokens, lengths, output_logits
 
             else:
                 if mpu.is_pipeline_first_stage():
@@ -385,9 +385,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None, None, None
+                    yield tokens, None, None
                 else:
-                    yield None, None, None, None
+                    yield None, None, None
 
             done = torch.cuda.ByteTensor([0])
             src = mpu.get_pipeline_model_parallel_last_rank()
...
@@ -25,10 +25,10 @@ if __name__ == "__main__":
     url = sys.argv[1]
     while True:
         sentence = raw_input("Enter prompt: ")
-        max_len = int(input("Enter number tokens output: "))
-        data = json.dumps({"sentences": [sentence], "max_len":max_len})
+        tokens_to_generate = int(input("Enter number of tokens to generate: "))
+        data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
         req = PutRequest(url, data, {'Content-Type': 'application/json'})
         response = urllib2.urlopen(req)
         resp_sentences = json.load(response)
         print("Megatron Response: ")
-        print(resp_sentences["sentences"][0])
+        print(resp_sentences["text"][0])
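The updated client still relies on Python 2 idioms (raw_input, urllib2, and the script's PutRequest helper). If useful, an equivalent Python 3 loop using only the standard library might look like the sketch below; the PUT method mirrors the original PutRequest helper and the URL argument handling is assumed to be unchanged.

# Python 3 sketch of the same loop: urllib.request replaces urllib2, input replaces raw_input.
import json
import sys
import urllib.request

url = sys.argv[1]
while True:
    sentence = input("Enter prompt: ")
    tokens_to_generate = int(input("Enter number of tokens to generate: "))
    data = json.dumps({"prompts": [sentence],
                       "tokens_to_generate": tokens_to_generate}).encode("utf-8")
    req = urllib.request.Request(url, data=data, method="PUT",
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as response:
        resp = json.loads(response.read().decode("utf-8"))
    print("Megatron Response: ")
    print(resp["text"][0])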