Unverified Commit e45bc143 authored by jprivera44's avatar jprivera44 Committed by GitHub
Browse files

Beam search type (#24288)

* test check in

* adding in type hint fix on beam search

* fixed code quality issue
parent 1a113fcf
......@@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -211,7 +211,7 @@ class BeamSearchScorer(BeamScorer):
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
) -> Dict[str, torch.Tensor]:
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment