Unverified Commit bf3dfd11 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

CI / generate: batch size computation compatible with all models (#29671)

parent 00c1d87a
@@ -1949,11 +1949,9 @@ class GenerationMixin:
     )
     # keep track of which sequences are already finished
-    batch_size, cur_len = (
-        model_kwargs["attention_mask"].shape
-        if model_kwargs.get("attention_mask", None) is not None
-        else input_ids.shape
-    )
+    batch_size, cur_len = input_ids.shape
+    if "inputs_embeds" in model_kwargs:
+        cur_len = model_kwargs["inputs_embeds"].shape[1]
     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
     model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -2398,12 +2396,10 @@ class GenerationMixin:
     )
     # keep track of which sequences are already finished
+    batch_size, cur_len = input_ids.shape
+    if "inputs_embeds" in model_kwargs:
+        cur_len = model_kwargs["inputs_embeds"].shape[1]
     this_peer_finished = False
-    batch_size, cur_len = (
-        model_kwargs["attention_mask"].shape
-        if model_kwargs.get("attention_mask", None) is not None
-        else input_ids.shape
-    )
     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
     model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -2686,12 +2682,10 @@ class GenerationMixin:
     )
     # keep track of which sequences are already finished
+    batch_size, cur_len = input_ids.shape
+    if "inputs_embeds" in model_kwargs:
+        cur_len = model_kwargs["inputs_embeds"].shape[1]
     this_peer_finished = False
-    batch_size, cur_len = (
-        model_kwargs["attention_mask"].shape
-        if model_kwargs.get("attention_mask", None) is not None
-        else input_ids.shape
-    )
     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
     model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -4461,11 +4455,9 @@ class GenerationMixin:
     )
     # keep track of which sequences are already finished
-    batch_size, cur_len = (
-        model_kwargs["attention_mask"].shape
-        if model_kwargs.get("attention_mask", None) is not None
-        else input_ids.shape
-    )
+    batch_size, cur_len = input_ids.shape
+    if "inputs_embeds" in model_kwargs:
+        cur_len = model_kwargs["inputs_embeds"].shape[1]
     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
     model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment