"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9c4f71a62adcfb1556a0bdd352cd4ed95c15b69d"
Unverified commit 53e9aacc, authored by Anatoly Belikov, committed by GitHub

log loss per image (#7278)

* log loss per image

* add commandline param for per image loss logging

* style

* debug-loss -> debug_loss

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 41424466
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--debug_loss",
+        action="store_true",
+        help="debug loss for each image, if filenames are available in the dataset",
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
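As a standalone illustration of the new flag (the parser below is a stand-in with only this one argument, not the training script's full parser), a `store_true` argument defaults to False and flips to True when the flag is passed:

```python
# Sketch of --debug_loss in isolation; argparse sets it to False unless
# the flag appears on the command line.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--debug_loss",
    action="store_true",
    help="debug loss for each image, if filenames are available in the dataset",
)

assert parser.parse_args([]).debug_loss is False               # flag absent
assert parser.parse_args(["--debug_loss"]).debug_loss is True  # flag present
```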
@@ -603,6 +608,7 @@ def main(args):
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
     if args.pretrained_vae_model_name_or_path is None:
         vae.to(accelerator.device, dtype=torch.float32)
     else:
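The context above keeps the VAE in float32 even when the rest of the pipeline is cast to a lower-precision `weight_dtype`. A minimal sketch of that placement pattern, with small `nn.Conv2d` modules standing in for the real UNet and VAE:

```python
# Hypothetical stand-ins for the real models; only the .to(...) pattern
# matters: the main model runs in weight_dtype, the VAE stays in float32
# to avoid NaN losses.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16  # e.g. under fp16 mixed precision

unet = nn.Conv2d(4, 4, 3)  # stand-in for the UNet
vae = nn.Conv2d(3, 4, 3)   # stand-in for the VAE

unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=torch.float32)  # kept in full precision
```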
@@ -890,13 +896,17 @@ def main(args):
         tokens_one, tokens_two = tokenize_captions(examples)
         examples["input_ids_one"] = tokens_one
         examples["input_ids_two"] = tokens_two
+        if args.debug_loss:
+            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+            if fnames:
+                examples["filenames"] = fnames
         return examples

     with accelerator.main_process_first():
         if args.max_train_samples is not None:
             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)

     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
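The new `fnames` list relies on PIL setting `.filename` when an image is opened from a path on disk; images decoded from an in-memory buffer get an empty string and are filtered out by `if image.filename`. A small self-contained check (the temporary file is for illustration only):

```python
import io
import os
import tempfile

from PIL import Image

# Write a throwaway image to disk so Image.open sees a real path.
path = os.path.join(tempfile.mkdtemp(), "sample.png")
Image.new("RGB", (8, 8)).save(path)

opened = Image.open(path)
print(os.path.basename(opened.filename))  # "sample.png" -> what gets logged

# Opening from an in-memory buffer leaves .filename empty, so such
# images are skipped by the `if image.filename` filter above.
buf = io.BytesIO()
Image.new("RGB", (8, 8)).save(buf, format="PNG")
buf.seek(0)
print(repr(Image.open(buf).filename))  # ''
```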
@@ -905,7 +915,7 @@ def main(args):
         crop_top_lefts = [example["crop_top_lefts"] for example in examples]
         input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
         input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
-        return {
+        result = {
             "pixel_values": pixel_values,
             "input_ids_one": input_ids_one,
             "input_ids_two": input_ids_two,
@@ -913,6 +923,11 @@ def main(args):
             "crop_top_lefts": crop_top_lefts,
         }
+        filenames = [example["filenames"] for example in examples if "filenames" in example]
+        if filenames:
+            result["filenames"] = filenames
+        return result

     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
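Because "filenames" is attached only when `--debug_loss` is on and PIL actually exposed filenames, the collate function has to add the key conditionally. A minimal sketch of that behavior, with single-field examples standing in for the real dataset rows:

```python
# Tensors are stacked into a batch as usual; "filenames" is carried
# through only when the examples provide it.
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    result = {"pixel_values": pixel_values}
    filenames = [example["filenames"] for example in examples if "filenames" in example]
    if filenames:
        result["filenames"] = filenames
    return result

with_names = [{"pixel_values": torch.zeros(3, 4, 4), "filenames": "a.png"}]
without = [{"pixel_values": torch.zeros(3, 4, 4)}]
print(collate_fn(with_names).keys())  # dict_keys(['pixel_values', 'filenames'])
print(collate_fn(without).keys())     # dict_keys(['pixel_values'])
```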
@@ -1105,7 +1120,9 @@ def main(args):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                 loss = loss.mean()
+                if args.debug_loss and "filenames" in batch:
+                    for fname in batch["filenames"]:
+                        accelerator.log({"loss_for_" + fname: loss}, step=global_step)

                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
...
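A minimal sketch of the logging pattern added above, with a plain dict standing in for the tracker backend behind `accelerator.log` (the batch contents and loss value are fabricated for illustration):

```python
# Each filename in the batch gets its own metric key, "loss_for_<name>",
# recorded against the current global step.
import torch

logs = {}  # stand-in for a tracker such as tensorboard or wandb

def log(values, step):
    logs.setdefault(step, {}).update(values)

global_step = 42
batch = {"filenames": ["cat.png", "dog.png"]}  # as built by collate_fn
loss = torch.tensor(0.125)  # batch-mean loss, as computed above

for fname in batch["filenames"]:
    log({"loss_for_" + fname: loss.item()}, step=global_step)

print(logs)  # {42: {'loss_for_cat.png': 0.125, 'loss_for_dog.png': 0.125}}
```

Note that `loss` is already the batch mean at this point, so with train_batch_size greater than 1 every filename in the batch receives the same value; the log is truly per-image only at batch size 1.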