Commit 453414da authored by rprenger

Removing unnecessary --recompute path

parent f7fe3865
...
@@ -189,40 +189,30 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         lengths = torch.ones([batch_size]).long().cuda() * maxlen
         while context_length <= (maxlen):
-            if args.recompute:
-                output = forward_step(model, tokens,
-                                      position_ids,
-                                      attention_mask,
-                                      tokentype_ids=type_ids,
-                                      forward_method_parallel_output=False)
-                if mpu.is_pipeline_last_stage():
-                    assert output is not None
-                    logits = output[:, context_length - 1, :]
-            else:
-                types2use = None
-                if counter == 0:
-                    tokens2use = tokens[:, :context_length]
-                    positions2use = position_ids[:, :context_length]
-                    if type_ids is not None:
-                        types2use = type_ids[:, :context_length]
-                else:
-                    tokens2use = tokens[:, context_length - 1].view(
-                        batch_size, -1)
-                    positions2use = position_ids[:, context_length - 1].view(
-                        batch_size, -1)
-                    if type_ids is not None:
-                        types2use = type_ids[:, context_length - 1].view(
-                            batch_size, -1)
-                output, layer_past = forward_step(model, tokens2use,
-                                                  positions2use,
-                                                  attention_mask,
-                                                  layer_past=layer_past,
-                                                  get_key_value=True,
-                                                  tokentype_ids=types2use,
-                                                  forward_method_parallel_output=False)
-                if mpu.is_pipeline_last_stage():
-                    assert output is not None
-                    logits = output[:, -1].view(batch_size, -1).contiguous()
+            types2use = None
+            if counter == 0:
+                tokens2use = tokens[:, :context_length]
+                positions2use = position_ids[:, :context_length]
+                if type_ids is not None:
+                    types2use = type_ids[:, :context_length]
+            else:
+                tokens2use = tokens[:, context_length - 1].view(
+                    batch_size, -1)
+                positions2use = position_ids[:, context_length - 1].view(
+                    batch_size, -1)
+                if type_ids is not None:
+                    types2use = type_ids[:, context_length - 1].view(
+                        batch_size, -1)
+            output, layer_past = forward_step(model, tokens2use,
+                                              positions2use,
+                                              attention_mask,
+                                              layer_past=layer_past,
+                                              get_key_value=True,
+                                              tokentype_ids=types2use,
+                                              forward_method_parallel_output=False)
+            if mpu.is_pipeline_last_stage():
+                assert output is not None
+                logits = output[:, -1].view(batch_size, -1).contiguous()
             if mpu.is_pipeline_last_stage():
                 if args.greedy:
...
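The retained branch is the incremental decoding path: on the first iteration the full prompt is fed through the model, and every later iteration feeds only the newest token together with the cached keys/values returned as layer_past. A minimal, self-contained sketch of that loop structure follows; toy_forward_step and the greedy token pick are illustrative stand-ins, not Megatron's real forward_step or sampling code.

```python
import torch

def toy_forward_step(tokens2use, positions2use, layer_past=None):
    # Hypothetical stand-in for Megatron's forward_step: returns random
    # per-position "logits" and an opaque cache object ("layer_past").
    batch_size, seq_len = tokens2use.shape
    vocab_size = 8
    logits = torch.randn(batch_size, seq_len, vocab_size)
    new_past = ([] if layer_past is None else layer_past) + [tokens2use]
    return logits, new_past

def incremental_decode(tokens, position_ids, context_length, maxlen):
    # Mirrors the retained loop: feed the whole prompt once (counter == 0),
    # then only the latest token plus the cached keys/values on later steps.
    batch_size = tokens.size(0)
    layer_past = None
    counter = 0
    while context_length <= maxlen:
        if counter == 0:
            tokens2use = tokens[:, :context_length]
            positions2use = position_ids[:, :context_length]
        else:
            tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
            positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
        output, layer_past = toy_forward_step(tokens2use, positions2use,
                                              layer_past=layer_past)
        # Only the logits at the last fed position matter for the next token.
        logits = output[:, -1].view(batch_size, -1).contiguous()
        next_token = logits.argmax(dim=-1)  # greedy pick, for the sketch only
        if context_length < tokens.size(1):
            tokens[:, context_length] = next_token
        context_length += 1
        counter += 1
    return tokens

if __name__ == "__main__":
    batch_size, total_len, prompt_len = 2, 12, 4
    tokens = torch.randint(0, 8, (batch_size, total_len))
    position_ids = torch.arange(total_len).unsqueeze(0).expand(batch_size, -1).clone()
    print(incremental_decode(tokens, position_ids, prompt_len, total_len - 1))
```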
...
@@ -55,10 +55,6 @@ def add_text_generate_args(parser):
                        help='Top k sampling.')
     group.add_argument("--out-seq-length", type=int, default=1024,
                        help='Size of the output generated text.')
-    group.add_argument("--recompute", action='store_true',
-                       help='During generation recompute all attention '
-                       'instead of using previously computed keys/values.')
     return parser
...
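After this change the argument group no longer exposes --recompute; generation always uses the cached key/value path. A minimal argparse sketch of the trimmed group is below, showing only the flag that appears in full in the diff (the group title and any other flags are assumptions).

```python
import argparse

def add_text_generate_args(parser):
    # Sketch of the trimmed argument group. Only --out-seq-length appears in
    # full in the diff; the group title and any other flags are assumptions.
    group = parser.add_argument_group(title='text generation')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    # The --recompute flag is gone: generation always reuses cached
    # keys/values (layer_past) instead of recomputing all attention.
    return parser

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = add_text_generate_args(parser)
    print(parser.parse_args(["--out-seq-length", "256"]))
```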