Unverified Commit 1dd0ac94 authored by Radamés Ajna's avatar Radamés Ajna Committed by GitHub

[DPO Training] pass tracker name as argument (#6542)

pass tracker name as argument
parent c6b04589
@@ -414,6 +414,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--tracker_name",
+        type=str,
+        default="diffusion-dpo-lora",
+        help=("The name of the tracker to report results to."),
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -726,7 +732,7 @@ def main(args):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("diffusion-dpo-lora", config=vars(args))
+        accelerator.init_trackers(args.tracker_name, config=vars(args))
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
...
@@ -429,6 +429,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--tracker_name",
+        type=str,
+        default="diffusion-dpo-lora-sdxl",
+        help=("The name of the tracker to report results to."),
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -821,7 +827,7 @@ def main(args):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("diffusion-dpo-lora-sdxl", config=vars(args))
+        accelerator.init_trackers(args.tracker_name, config=vars(args))
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
...
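
For context, a minimal, self-contained sketch of the pattern this commit enables: an argparse flag feeding Accelerate's init_trackers instead of a hard-coded project name. This is not the repository's training script; the wandb backend and the surrounding scaffolding are assumptions for illustration only.

import argparse

from accelerate import Accelerator


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Toy illustration of --tracker_name")
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="diffusion-dpo-lora",
        help="The name of the tracker to report results to.",
    )
    return parser.parse_args(input_args)


def main(args):
    # log_with="wandb" is an illustrative choice; it requires wandb to be installed.
    accelerator = Accelerator(log_with="wandb")
    if accelerator.is_main_process:
        # Previously hard-coded as "diffusion-dpo-lora"; now taken from the CLI.
        accelerator.init_trackers(args.tracker_name, config=vars(args))
    # ... training loop would go here ...
    accelerator.end_training()


if __name__ == "__main__":
    main(parse_args())

With a sketch like this, a run could hypothetically be launched as accelerate launch sketch.py --tracker_name my-dpo-run, and the tracker would report under my-dpo-run rather than the previously fixed project name.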