"pipelines/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "80beca52c4269ac5dcd4955ae496f8c3a44d20ef"
Unverified Commit 34e07f4b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: remove unused attributes in `AssistedCandidateGenerator` (#29787)

remove unused attrs
parent e85654f5
...@@ -131,11 +131,9 @@ class AssistedCandidateGenerator(CandidateGenerator): ...@@ -131,11 +131,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
if assistant_model.config.is_encoder_decoder: if assistant_model.config.is_encoder_decoder:
# both are encoder-decoder # both are encoder-decoder
self.input_ids_key = "decoder_input_ids" self.input_ids_key = "decoder_input_ids"
self.attention_key = "decoder_attention_mask"
elif "encoder_outputs" in assistant_kwargs: elif "encoder_outputs" in assistant_kwargs:
# special case for encoder-decoder with decoder-only assistant (like DistilWhisper) # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
self.input_ids_key = "input_ids" self.input_ids_key = "input_ids"
self.attention_key = "attention_mask"
self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get( self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
"decoder_attention_mask", "decoder_attention_mask",
torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
...@@ -143,15 +141,8 @@ class AssistedCandidateGenerator(CandidateGenerator): ...@@ -143,15 +141,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
else: else:
# both are decoder-only # both are decoder-only
self.input_ids_key = "input_ids" self.input_ids_key = "input_ids"
self.attention_key = "attention_mask"
# Prepare generation-related options. # Prepare generation-related options.
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
)
self.logits_processor = logits_processor self.logits_processor = logits_processor
self.generation_config = copy.deepcopy(generation_config) self.generation_config = copy.deepcopy(generation_config)
self.generation_config.return_dict_in_generate = True self.generation_config.return_dict_in_generate = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment