Commit f2df39fa authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent 8ecdd3ef
...@@ -1076,7 +1076,9 @@ def main(): ...@@ -1076,7 +1076,9 @@ def main():
and global_step % args.validation_steps == 0 and global_step % args.validation_steps == 0
and jax.process_index() == 0 and jax.process_index() == 0
): ):
_ = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype) _ = log_validation(
pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
)
if global_step % args.logging_steps == 0 and jax.process_index() == 0: if global_step % args.logging_steps == 0 and jax.process_index() == 0:
if args.report_to == "wandb": if args.report_to == "wandb":
...@@ -1108,7 +1110,9 @@ def main(): ...@@ -1108,7 +1110,9 @@ def main():
if args.validation_prompt is not None: if args.validation_prompt is not None:
if args.profile_validation: if args.profile_validation:
jax.profiler.start_trace(args.output_dir) jax.profiler.start_trace(args.output_dir)
image_logs = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype) image_logs = log_validation(
pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
)
if args.profile_validation: if args.profile_validation:
jax.profiler.stop_trace() jax.profiler.stop_trace()
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment