Unverified Commit 01b6ec21 authored by Xinyang Li's avatar Xinyang Li Committed by GitHub
Browse files

fix validation option for dreambooth training example (#4317)

parent 92e5ddd2
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
import argparse import argparse
import copy
import gc import gc
import hashlib import hashlib
import itertools import itertools
...@@ -1116,7 +1117,7 @@ def main(args): ...@@ -1116,7 +1117,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration. # We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process. # The trackers initializes automatically on the main process.
if accelerator.is_main_process: if accelerator.is_main_process:
tracker_config = vars(args) tracker_config = vars(copy.deepcopy(args))
tracker_config.pop("validation_images") tracker_config.pop("validation_images")
accelerator.init_trackers("dreambooth", config=tracker_config) accelerator.init_trackers("dreambooth", config=tracker_config)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
import argparse import argparse
import copy
import gc import gc
import hashlib import hashlib
import itertools import itertools
...@@ -1067,7 +1068,7 @@ def main(args): ...@@ -1067,7 +1068,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration. # We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process. # The trackers initializes automatically on the main process.
if accelerator.is_main_process: if accelerator.is_main_process:
tracker_config = vars(args) tracker_config = vars(copy.deepcopy(args))
tracker_config.pop("validation_images") tracker_config.pop("validation_images")
accelerator.init_trackers("dreambooth-lora", config=tracker_config) accelerator.init_trackers("dreambooth-lora", config=tracker_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment