Commit 10ef6f6c authored by Yoach Lacombe

improve torch compile logic

parent 82cbc3ad
@@ -498,15 +498,7 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
             )
         },
     )
-    text_encode_per_device_eval_batch_size: int = field(
-        default=8,
-        metadata={
-            "help": (
-                "TODO"
-            )
-        },
-    )
 
 @dataclass
 class DataCollatorEncodecWithPadding:
     """
@@ -821,7 +813,7 @@ def main():
     kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
     if training_args.torch_compile:
         # TODO(YL): add more compile modes?
-        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor"))
+        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default"))  # alternative: mode="reduce-overhead"
 
     accelerator = Accelerator(
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
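
For reference, torch.compile's inductor backend accepts "reduce-overhead" and "max-autotune" modes besides "default". A minimal sketch of how the TODO above could be addressed, assuming a hypothetical compile_mode setting that is not part of this script:

from datetime import timedelta

from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs, TorchDynamoPlugin

def build_kwargs_handlers(torch_compile: bool, compile_mode: str = "default"):
    # torch.compile modes: "default" compiles fastest, "reduce-overhead" uses
    # CUDA graphs to cut per-call overhead, "max-autotune" searches for the
    # fastest kernels at the cost of longer compilation.
    handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
    if torch_compile:
        handlers.append(TorchDynamoPlugin(backend="inductor", mode=compile_mode))
    return handlers

accelerator = Accelerator(kwargs_handlers=build_kwargs_handlers(torch_compile=True))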
@@ -1493,27 +1485,33 @@
     # Define eval fn
     def eval_step(batch, accelerator, autocast_kwargs):
-        model.eval()
+        eval_model = model if not training_args.torch_compile else model._orig_mod
+        eval_model.eval()
 
         if mixed_precision == "fp16":
             # fp16 doesn't work with T5-like models
             with accelerator.autocast(autocast_handler=autocast_kwargs):
-                if training_args.parallel_mode.value != "distributed":
-                    encoder_outputs = model.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
-                else:
-                    encoder_outputs = model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+                with torch.no_grad():
+                    if training_args.parallel_mode.value != "distributed" or training_args.torch_compile:
+                        encoder_outputs = eval_model.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+                    else:
+                        encoder_outputs = eval_model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+
                 batch["encoder_outputs"] = encoder_outputs
 
         with torch.no_grad():
-            outputs = model(**batch)
+            outputs = eval_model(**batch)
 
         # CE (data) loss
         ce_loss = outputs.loss
         metrics = {"loss": ce_loss}
         return metrics
 
     def generate_step(batch):
-        model.eval()
         batch.pop("decoder_attention_mask", None)
-        output_audios = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").generate(**batch, **gen_kwargs)
+        eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval()
+        if training_args.torch_compile:
+            eval_model = model._orig_mod
+        output_audios = eval_model.generate(**batch, **gen_kwargs)
         output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
         return output_audios
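
As background, torch.compile wraps a module in an OptimizedModule that keeps the original module under _orig_mod; routing evaluation and generation through the uncompiled module avoids recompilations triggered by the changing tensor shapes of autoregressive decoding. A minimal standalone sketch of the pattern, using plain PyTorch rather than this script's model:

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)

# The compiled wrapper exposes the original module as _orig_mod.
assert compiled._orig_mod is net

# Fall back to the model itself when it was never compiled.
eval_model = compiled._orig_mod if hasattr(compiled, "_orig_mod") else compiled
eval_model.eval()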
@@ -1593,7 +1591,6 @@
             if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
                 train_time += time.time() - train_start
-                model.eval()
 
                 # ======================== Evaluating ==============================
                 eval_metrics = []
                 eval_preds = []
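
Since eval_step and generate_step now put the correct module (the unwrapped or uncompiled one) into eval mode themselves, the extra model.eval() in the outer evaluation loop became redundant, which is why this hunk drops it. A hypothetical sketch of the resulting flow, with an invented run_evaluation name:

def run_evaluation(eval_dataloader):
    # Mode switching happens inside eval_step itself, so the loop only
    # aggregates per-batch metrics.
    eval_metrics = []
    for batch in eval_dataloader:
        eval_metrics.append(eval_step(batch, accelerator, autocast_kwargs))
    return eval_metrics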