Commit 80da6b4c authored by yoach@huggingface.co

fix eval when fp16 + remove useless code

parent e51113f9
@@ -192,22 +192,6 @@ def log_pred(
                 ]},
             step=step)
 #### ARGUMENTS
-class StableSpeechTrainer(Seq2SeqTrainer):
-    def _pad_tensors_to_max_len(self, tensor, max_length):
-        if self.model.config.pad_token_id is not None:
-            pad_token_id = self.model.config.pad_token_id
-        else:
-            raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
-        padded_tensor = pad_token_id * torch.ones(
-            (tensor.shape[0], max_length, tensor.shape[2]), dtype=tensor.dtype, device=tensor.device
-        )
-        padded_tensor[:, : tensor.shape[1]] = tensor
-        return padded_tensor
 @dataclass
 class ModelArguments:
     """
@@ -1349,8 +1333,13 @@ def main():
         return ce_loss, metrics
     # Define eval fn
-    def eval_step(batch):
+    def eval_step(batch, accelerator, autocast_kwargs,):
         model.eval()
+        if mixed_precision == "fp16":
+            # fp16 doesn't work with T5-like models
+            with accelerator.autocast(autocast_handler=autocast_kwargs):
+                encoder_outputs = model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+            batch["encoder_outputs"] = encoder_outputs
         with torch.no_grad():
             outputs = model(**batch)
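The fix works because T5-like text encoders are numerically unstable under fp16 autocast, so the text encoder is run separately with autocast disabled and its output is handed to the rest of the model as precomputed encoder_outputs. A minimal sketch of that pattern, assuming autocast_kwargs is an accelerate.utils.AutocastKwargs handler that turns autocast off when mixed precision is fp16 (how the script builds it is not shown in this diff), that the model exposes a text_encoder submodule as in the script, and that the return value is simplified to just the loss; the diff reaches the encoder through model.module (the DDP wrapper), while unwrap_model is used here to keep the sketch launcher-agnostic:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

accelerator = Accelerator()  # the real run is launched with --mixed_precision fp16
mixed_precision = accelerator.mixed_precision  # "no", "fp16" or "bf16"

# Assumption: disable autocast for the text encoder when running in fp16.
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))

def eval_step(batch, model, accelerator, autocast_kwargs):
    model.eval()
    if mixed_precision == "fp16":
        # Run the T5-like text encoder with autocast disabled (full precision) ...
        with accelerator.autocast(autocast_handler=autocast_kwargs):
            encoder_outputs = accelerator.unwrap_model(model).text_encoder(
                input_ids=batch.get("input_ids"),
                attention_mask=batch.get("attention_mask", None),
            )
        # ... and cache the outputs so the full forward pass skips re-encoding the text.
        batch["encoder_outputs"] = encoder_outputs
    with torch.no_grad():
        outputs = model(**batch)
    # Simplified return: the script derives its metrics from these outputs (labels assumed in batch).
    return {"loss": outputs.loss}
```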
@@ -1361,7 +1350,7 @@ def main():
     def generate_step(batch):
         model.eval()
-        output_audios = accelerator.unwrap_model(model).generate(**batch, **gen_kwargs)
+        output_audios = accelerator.unwrap_model(model, keep_fp32_wrapper = mixed_precision != "fp16").generate(**batch, **gen_kwargs)
         output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
         return output_audios
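accelerate's prepare() patches the model's forward with a mixed-precision wrapper, and recent versions of unwrap_model keep that wrapper by default. Passing keep_fp32_wrapper=False when running fp16 restores the original forward, so generation is not autocast back into float16, which would again break the T5-like encoder. A sketch of the generation step under those assumptions (gen_kwargs and the batch contents are placeholders, and the model/signature names are illustrative):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
mixed_precision = accelerator.mixed_precision

def generate_step(batch, model, gen_kwargs):
    model.eval()
    # Keep the mixed-precision wrapper unless running fp16, where it must be dropped
    # so .generate() runs the text encoder in full precision.
    unwrapped = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16")
    output_audios = unwrapped.generate(**batch, **gen_kwargs)
    # Pad generations to a common length so they can be gathered across processes.
    output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
    return output_audios
```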
@@ -1470,7 +1459,7 @@ def main():
             disable=not accelerator.is_local_main_process,
         ):
             # Model forward
-            eval_metric = eval_step(batch)
+            eval_metric = eval_step(batch, accelerator, autocast_kwargs)
             eval_metric = accelerator.gather_for_metrics(eval_metric)
             eval_metrics.append(eval_metric)
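For completeness, this is roughly how the evaluation loop around eval_step aggregates results. The final reduction is a sketch of typical usage, assuming eval_step returns a dict of metric tensors as in the sketch above (the script's exact metric handling is not shown in this diff); gather_for_metrics is what trims the duplicated samples that padding the last batch across processes would otherwise introduce:

```python
import torch
from tqdm import tqdm

def evaluate(eval_dataloader, model, accelerator, autocast_kwargs):
    eval_metrics = []
    for batch in tqdm(eval_dataloader, desc="Evaluating", disable=not accelerator.is_local_main_process):
        # Model forward (eval_step as sketched above).
        eval_metric = eval_step(batch, model, accelerator, autocast_kwargs)
        # Gather each process's metrics, dropping duplicates from the padded last batch.
        eval_metric = accelerator.gather_for_metrics(eval_metric)
        eval_metrics.append(eval_metric)
    # Average every metric key over all gathered batches (assumes dicts of tensors).
    return {
        key: torch.stack([m[key].float().mean() for m in eval_metrics]).mean().item()
        for key in eval_metrics[0]
    }
```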