Nuomanzz / TangoFlux · Commits

Commit 5120a1c8
authored Jan 05, 2025 by hungchiayu1
parent 35b838f1

    remove unused args

Showing 1 changed file with 4 additions and 69 deletions.

tangoflux/train_dpo.py (+4, -69)
@@ -101,12 +101,7 @@ def parse_args():
             "constant_with_warmup",
         ],
     )
-    parser.add_argument(
-        "--num_warmup_steps",
-        type=int,
-        default=0,
-        help="Number of steps for the warmup in the lr scheduler.",
-    )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
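For context, the deleted `--num_warmup_steps` flag matches the `num_warmup_steps` parameter of `transformers.get_scheduler`, which accelerate-style training scripts of this shape typically pass through when building the learning-rate scheduler; the diff alone does not show where (or whether) TangoFlux consumed it. A minimal sketch of that conventional wiring, with a toy model and optimizer standing in for the real ones:

    import torch
    from transformers import get_scheduler

    # Toy stand-ins; the real script builds these from TangoFlux modules.
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    lr_scheduler = get_scheduler(
        name="constant_with_warmup",  # one of the choices listed above
        optimizer=optimizer,
        num_warmup_steps=0,           # the flag this commit deletes defaulted to 0
        num_training_steps=1000,
    )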
@@ -135,23 +130,7 @@ def parse_args():
         help="Save model after every how many epochs when checkpointing_steps is set to best.",
     )
-    parser.add_argument(
-        "--resume_from_checkpoint",
-        type=str,
-        default=None,
-        help="If the training should continue from a local checkpoint folder.",
-    )
-    parser.add_argument(
-        "--report_to",
-        type=str,
-        default="all",
-        help=(
-            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
-            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
-            "Only applicable when `--with_tracking` is passed."
-        ),
-    )
     parser.add_argument(
         "--load_from_checkpoint",
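The deleted `--resume_from_checkpoint` and `--report_to` flags come from the stock `accelerate` example scripts. A `--report_to` value is conventionally forwarded to `Accelerator(log_with=...)` and activated with `init_trackers`; a hedged sketch of that usual pattern (the project name and config below are made up for illustration, not taken from train_dpo.py):

    from accelerate import Accelerator

    # Conventional consumption of a --report_to value; requires the chosen
    # tracker backend (here wandb) to be installed.
    accelerator = Accelerator(log_with="wandb")
    accelerator.init_trackers(project_name="tangoflux_dpo", config={"lr": 1e-4})
    accelerator.log({"train_loss": 0.0}, step=0)
    accelerator.end_training()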
@@ -159,12 +138,7 @@ def parse_args():
         default=None,
         help="Whether to continue training from a model weight",
     )
-    parser.add_argument(
-        "--audio_length",
-        type=float,
-        default=30,
-        help="Audio duration",
-    )
     args = parser.parse_args()
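A cut-down sketch of what `parse_args()` now returns for the flags that survive this commit; only arguments visible in this diff are reproduced, and the `adam_epsilon` default is an assumption (the real default is outside the shown hunks):

    import argparse

    # Reduced reproduction of the surviving parser, just to show the parsed
    # namespace; the real parser defines many more flags than these.
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr_scheduler_type", type=str, default="linear",
                        choices=["linear", "constant_with_warmup"])
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)  # assumed default
    parser.add_argument("--load_from_checkpoint", type=str, default=None)

    args = parser.parse_args([])  # empty argv -> all defaults
    print(args.lr_scheduler_type, args.adam_epsilon, args.load_from_checkpoint)
    # linear 1e-08 None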
@@ -362,46 +336,8 @@ def main():
-    for param in model.ref_transformer.parameters():
-        param.requires_grad = False
-
-    @torch.no_grad()
-    def initialize_or_update_ref_transformer(model, accelerator: Accelerator, alpha=0.5):
-        """
-        Initializes or updates ref_transformer as alpha * ref + 1-alpha * transformer.
-
-        Args:
-            model (torch.nn.Module): The main model containing the 'transformer' attribute.
-            accelerator (Accelerator): The Accelerator instance used to unwrap the model.
-            initial_ref_model (torch.nn.Module, optional): An optional initial reference model.
-                If not provided, ref_transformer is initialized as a copy of transformer.
-
-        Returns:
-            torch.nn.Module: The model with the updated ref_transformer.
-        """
-        # Unwrap the model to access the original underlying model
-        unwrapped_model = accelerator.unwrap_model(model)
-
-        with torch.no_grad():
-            for ref_param, model_param in zip(
-                unwrapped_model.ref_transformer.parameters(),
-                unwrapped_model.transformer.parameters(),
-            ):
-                average_param = alpha * ref_param.data + (1 - alpha) * model_param.data
-                ref_param.data.copy_(average_param)
-
-        unwrapped_model.ref_transformer.eval()
-        unwrapped_model.ref_transformer.requires_grad_ = False
-        for param in unwrapped_model.ref_transformer.parameters():
-            param.requires_grad = False
-
-        return model
+    model.ref_transformer = copy.deepcopy(model.transformer)
+    model.ref_transformer.requires_grad_ = False
+    model.ref_transformer.eval()
     for param in model.ref_transformer.parameters():
         param.requires_grad = False

     optimizer = torch.optim.AdamW(
         optimizer_parameters,
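The deleted helper blended reference and policy weights as alpha * ref + (1 - alpha) * transformer, an EMA-style reference update; the commit keeps only the one-time frozen deep copy, which is the usual fixed-reference setup for DPO. Note also that `requires_grad_ = False` (without parentheses, in both the old and new code) assigns over the `Module.requires_grad_` method rather than calling it; the explicit per-parameter loop is what actually freezes the reference. A standalone illustration of both behaviors on toy modules, not the TangoFlux classes:

    import copy
    import torch
    import torch.nn as nn

    # Toy stand-ins for the policy and reference transformers.
    policy = nn.Linear(4, 4)
    ref = copy.deepcopy(policy)

    # Pretend a few optimizer steps moved the policy away from the reference.
    with torch.no_grad():
        for p in policy.parameters():
            p.add_(torch.randn_like(p))

    # The deleted helper's update: ref <- alpha * ref + (1 - alpha) * policy.
    alpha = 0.5
    with torch.no_grad():
        for ref_param, model_param in zip(ref.parameters(), policy.parameters()):
            ref_param.data.copy_(alpha * ref_param.data + (1 - alpha) * model_param.data)

    # What the commit keeps instead: a one-time frozen deep copy.
    ref = copy.deepcopy(policy)
    ref.eval()
    for p in ref.parameters():
        p.requires_grad = False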
@@ -452,8 +388,7 @@ def main():
     if checkpointing_steps is not None and checkpointing_steps.isdigit():
         checkpointing_steps = int(checkpointing_steps)

     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     # Train!
     total_batch_size = (
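The `isdigit()` check shown above is how the script distinguishes a numeric `--checkpointing_steps` value from a keyword like `"best"` (mentioned in the help text earlier in this diff): numeric strings become integer step counts, anything else stays a string sentinel handled later in the training loop. A minimal illustration:

    # Same conversion as in the hunk above, run over representative values.
    for checkpointing_steps in ("500", "best", None):
        if checkpointing_steps is not None and checkpointing_steps.isdigit():
            checkpointing_steps = int(checkpointing_steps)
        print(repr(checkpointing_steps))
    # 500, then 'best', then None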