"src/nni_manager/vscode:/vscode.git/clone" did not exist on "b1a65978ca13ef8faa0aab0365c188f0db43c127"
Commit 90cda45e authored by patrickvonplaten's avatar patrickvonplaten
Browse files

add past re-ordering for beam search

parent 6bca56fd
...@@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module): ...@@ -913,13 +913,18 @@ class PreTrainedModel(nn.Module):
beam_words = input_ids.new([x[1] for x in next_batch_beam]) beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam]) beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and internal states # re-order batch
input_ids = input_ids[beam_idx, :] input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1) input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
# TODO: Activate cache
# for k in cache.keys(): # re-order internal states
# if k != 'slen': if past:
# cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx]) reordered_past = []
for layer_past in past:
# copy the relevant beam idx past to past
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_past.append(torch.cat(reordered_layer_past, dim=1))
past = tuple(reordered_past)
# update current length # update current length
cur_len = cur_len + 1 cur_len = cur_len + 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment