Unverified Commit 38a1b03f authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Remove unhelpful bart warning (#7391)

parent 5ff0d6d7
...@@ -550,13 +550,8 @@ class BartDecoder(nn.Module): ...@@ -550,13 +550,8 @@ class BartDecoder(nn.Module):
positions = self.embed_positions(input_ids, use_cache=use_cache) positions = self.embed_positions(input_ids, use_cache=use_cache)
if use_cache: if use_cache:
if input_ids.shape[1] != 1 or past_key_values is None:
# if you make this an AssertionError, test_benchmark breaks.
warnings.warn("pass decoder_past_key_value_states to use_cache")
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them positions = positions[:, -1:]
# assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids) * self.embed_scale x = self.embed_tokens(input_ids) * self.embed_scale
x += positions x += positions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment