Commit 10ef6f6c authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

improve torch compile logic

parent 82cbc3ad
...@@ -498,14 +498,6 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments): ...@@ -498,14 +498,6 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
) )
}, },
) )
text_encode_per_device_eval_batch_size: int = field(
default=8,
metadata={
"help": (
"TODO"
)
},
)
@dataclass @dataclass
class DataCollatorEncodecWithPadding: class DataCollatorEncodecWithPadding:
...@@ -821,7 +813,7 @@ def main(): ...@@ -821,7 +813,7 @@ def main():
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))] kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
if training_args.torch_compile: if training_args.torch_compile:
# TODO(YL): add more compile modes? # TODO(YL): add more compile modes?
kwargs_handlers.append(TorchDynamoPlugin(backend="inductor")) kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) #reduce-overhead
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps, gradient_accumulation_steps=training_args.gradient_accumulation_steps,
...@@ -1493,27 +1485,33 @@ def main(): ...@@ -1493,27 +1485,33 @@ def main():
# Define eval fn # Define eval fn
def eval_step(batch, accelerator, autocast_kwargs,): def eval_step(batch, accelerator, autocast_kwargs,):
model.eval() eval_model = model if not training_args.torch_compile else model._orig_mod
eval_model.eval()
if mixed_precision == "fp16": if mixed_precision == "fp16":
# fp16 doesn't work with T5-like models # fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs): with accelerator.autocast(autocast_handler=autocast_kwargs):
if training_args.parallel_mode.value != "distributed": with torch.no_grad():
encoder_outputs = model.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)) if training_args.parallel_mode.value != "distributed" or training_args.torch_compile:
encoder_outputs = eval_model.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
else: else:
encoder_outputs = model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)) encoder_outputs = eval_model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
batch["encoder_outputs"] = encoder_outputs batch["encoder_outputs"] = encoder_outputs
with torch.no_grad(): with torch.no_grad():
outputs = model(**batch) outputs = eval_model(**batch)
# CE (data) loss # CE (data) loss
ce_loss = outputs.loss ce_loss = outputs.loss
metrics = {"loss": ce_loss} metrics = {"loss": ce_loss}
return metrics return metrics
def generate_step(batch): def generate_step(batch):
model.eval()
batch.pop("decoder_attention_mask", None) batch.pop("decoder_attention_mask", None)
output_audios = accelerator.unwrap_model(model, keep_fp32_wrapper = mixed_precision != "fp16").generate(**batch, **gen_kwargs) eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper = mixed_precision != "fp16").eval()
if training_args.torch_compile:
eval_model = model._orig_mod
output_audios = eval_model.generate(**batch, **gen_kwargs)
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0) output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
return output_audios return output_audios
...@@ -1593,7 +1591,6 @@ def main(): ...@@ -1593,7 +1591,6 @@ def main():
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps): if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
train_time += time.time() - train_start train_time += time.time() - train_start
model.eval()
# ======================== Evaluating ============================== # ======================== Evaluating ==============================
eval_metrics = [] eval_metrics = []
eval_preds = [] eval_preds = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment