Commit 5120a1c8 authored by hungchiayu1

remove unused args

parent 35b838f1
@@ -101,12 +101,7 @@ def parse_args():
"constant_with_warmup",
],
)
parser.add_argument(
"--num_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -135,23 +130,7 @@ def parse_args():
help="Save model after every how many epochs when checkpointing_steps is set to best.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a local checkpoint folder.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument(
"--load_from_checkpoint",
@@ -159,12 +138,7 @@ def parse_args():
default=None,
help="Whether to continue training from a model weight",
)
parser.add_argument(
"--audio_length",
type=float,
default=30,
help="Audio duration",
)
args = parser.parse_args()
@@ -362,46 +336,8 @@ def main():
for param in model.ref_transformer.parameters():
param.requires_grad = False
@torch.no_grad()
def initialize_or_update_ref_transformer(
model, accelerator: Accelerator, alpha=0.5
):
"""
Initializes or updates ref_transformer as alpha * ref + 1-alpha * transformer.
Args:
model (torch.nn.Module): The main model containing the 'transformer' attribute.
accelerator (Accelerator): The Accelerator instance used to unwrap the model.
initial_ref_model (torch.nn.Module, optional): An optional initial reference model.
If not provided, ref_transformer is initialized as a copy of transformer.
Returns:
torch.nn.Module: The model with the updated ref_transformer.
"""
# Unwrap the model to access the original underlying model
unwrapped_model = accelerator.unwrap_model(model)
with torch.no_grad():
for ref_param, model_param in zip(
unwrapped_model.ref_transformer.parameters(),
unwrapped_model.transformer.parameters(),
):
average_param = alpha * ref_param.data + (1 - alpha) * model_param.data
ref_param.data.copy_(average_param)
unwrapped_model.ref_transformer.eval()
unwrapped_model.ref_transformer.requires_grad_ = False
for param in unwrapped_model.ref_transformer.parameters():
param.requires_grad = False
return model
model.ref_transformer = copy.deepcopy(model.transformer)
model.ref_transformer.requires_grad_ = False
model.ref_transformer.eval()
for param in model.ref_transformer.parameters():
param.requires_grad = False
optimizer = torch.optim.AdamW(
optimizer_parameters,
@@ -452,8 +388,7 @@ def main():
if checkpointing_steps is not None and checkpointing_steps.isdigit():
checkpointing_steps = int(checkpointing_steps)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
# Train!
total_batch_size = (
......
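For context on the last hunk: the removed `initialize_or_update_ref_transformer` helper blended a frozen reference transformer toward the trained one as `alpha * ref + (1 - alpha) * transformer`, and the surviving code keeps only a deep-copied, frozen reference. The following is a minimal, self-contained sketch of that pattern using a stand-in `nn.Linear` in place of the transformer; it is an illustration under those assumptions, not the repository's code, and it freezes parameters via the in-place method call `requires_grad_(False)`.

```python
import copy

import torch
import torch.nn as nn

# Stand-in module; any nn.Module (e.g. a transformer) works the same way.
transformer = nn.Linear(8, 8)

# Frozen reference copy: deepcopy, switch to eval mode, disable gradients.
# Note: requires_grad_(False) is a method call; assigning to the attribute
# (ref.requires_grad_ = False) would have no effect on the parameters.
ref_transformer = copy.deepcopy(transformer)
ref_transformer.eval()
ref_transformer.requires_grad_(False)


@torch.no_grad()
def soft_update(ref: nn.Module, model: nn.Module, alpha: float = 0.5) -> None:
    """Blend reference weights toward the trained model:
    ref <- alpha * ref + (1 - alpha) * model."""
    for ref_param, model_param in zip(ref.parameters(), model.parameters()):
        ref_param.data.mul_(alpha).add_(model_param.data, alpha=1 - alpha)


soft_update(ref_transformer, transformer, alpha=0.5)
```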