"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a8694b8850ad310e1eebdfec3c90ef0ababce56f"
Unverified Commit b13c6c18 authored by sourabh112's avatar sourabh112 Committed by GitHub
Browse files

correcting group beam search function output score bug (#13211)

parent f689743e
...@@ -2403,6 +2403,9 @@ class GenerationMixin: ...@@ -2403,6 +2403,9 @@ class GenerationMixin:
cur_len = cur_len + 1 cur_len = cur_len + 1
continue # don't waste resources running the code we don't need continue # don't waste resources running the code we don't need
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
for beam_group_idx in range(num_beam_groups): for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
...@@ -2411,9 +2414,6 @@ class GenerationMixin: ...@@ -2411,9 +2414,6 @@ class GenerationMixin:
# indices of beams of current group among all sentences in batch # indices of beams of current group among all sentences in batch
batch_group_indices = [] batch_group_indices = []
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
batch_group_indices.extend( batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment