Commit c81bf127 authored by Jared Casper

Merge branch 'memory_save' into 'main'

Reducing memory usage during inference

See merge request ADLR/megatron-lm!320
parents 8fe6f9bf faf58b77
@@ -85,21 +85,22 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     return logits
 
-def pad_batch(batch, pad_id, args):
+def pad_batch(batch, pad_id, max_len):
     context_lengths = []
+    max_context_length = max([len(tokens) for tokens in batch])
     for tokens in batch:
         context_length = len(tokens)
-        if context_length < args.seq_length:
-            tokens.extend([pad_id] * (args.seq_length - context_length))
+        if context_length < max_context_length + max_len:
+            tokens.extend([pad_id] * (max_context_length + max_len - context_length))
         context_lengths.append(context_length)
     return batch, context_lengths
 
-def tokenize_batch(sentences):
+def tokenize_batch(sentences, max_len):
     args = get_args()
     tokenizer = get_tokenizer()
     context_tokens = [tokenizer.tokenize(s) for s in sentences]
     context_tokens, context_lengths = pad_batch(context_tokens,
-                                                tokenizer.eod, args)
+                                                tokenizer.eod, max_len)
     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
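
For reference, a standalone sketch of the new padding behaviour: instead of padding every example out to args.seq_length, the batch is now padded only to the longest context in the batch plus max_len tokens of generation headroom. The example batch, pad_id=0, and max_len=4 below are made-up values for illustration; in the merge request pad_id comes from tokenizer.eod and the tokens from tokenizer.tokenize.

# Minimal sketch of the revised pad_batch; the batch contents are hypothetical.
def pad_batch(batch, pad_id, max_len):
    """Pad each sequence to (longest context in batch + max_len), not to args.seq_length."""
    context_lengths = []
    max_context_length = max(len(tokens) for tokens in batch)
    for tokens in batch:
        context_length = len(tokens)
        if context_length < max_context_length + max_len:
            tokens.extend([pad_id] * (max_context_length + max_len - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

batch = [[11, 12, 13], [21, 22, 23, 24, 25]]        # hypothetical token ids
padded, lengths = pad_batch(batch, pad_id=0, max_len=4)
print(lengths)                   # [3, 5]  original context lengths
print([len(t) for t in padded])  # [9, 9]  longest context (5) + max_len (4)
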
@@ -178,11 +179,11 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
 def generate(model, sentences=None, max_len=0, all_probs=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
-        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
+        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, max_len)
         send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
     output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
     if output is not None:
         decode_tokens, output_logits, full_logits = output
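
Why this reduces memory during inference: previously every batch was padded to the full args.seq_length, so the token tensor and everything the generation loop computes from it spanned the entire model sequence length even for short prompts. A rough, hypothetical before/after comparison of the padded token tensor shape (batch size, sequence length, and prompt/generation lengths below are made up, and the tensors are built on CPU just for illustration):

# Hypothetical comparison of old vs. new padding targets.
import torch

batch_size  = 8
seq_length  = 2048   # stand-in for args.seq_length (old padding target)
longest_ctx = 37     # longest prompt in the batch
max_len     = 64     # tokens to generate

old_tokens = torch.zeros(batch_size, seq_length, dtype=torch.long)
new_tokens = torch.zeros(batch_size, longest_ctx + max_len, dtype=torch.long)

print(old_tokens.shape)  # torch.Size([8, 2048])
print(new_tokens.shape)  # torch.Size([8, 101])

The LongTensor itself is small; the larger effect is presumably that the sampling loop and the model forward passes now operate over roughly 100 positions instead of 2048 for a batch like this one.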