Commit 1e66b652 authored by Minseo Kang's avatar Minseo Kang
Browse files

feat: update pytorch-lightning version support

parent 681d9aa3
......@@ -69,6 +69,7 @@ class SwinEncoder(nn.Module):
num_heads=[4, 8, 16, 32],
num_classes=0,
)
self.model.norm = None
# weight init with swin
if not name_or_path:
......
......@@ -44,6 +44,8 @@ class DonutModelPLModule(pl.LightningModule):
# encoder_layer=[2,2,14,2], decoder_layer=4, ...
)
)
self.pytorch_lightning_version_is_1 = int(pl.__version__[0]) < 2
self.num_of_loaders = len(self.config.dataset_name_or_paths)
def training_step(self, batch, batch_idx):
image_tensors, decoder_input_ids, decoder_labels = list(), list(), list()
......@@ -56,9 +58,16 @@ class DonutModelPLModule(pl.LightningModule):
decoder_labels = torch.cat(decoder_labels)
loss = self.model(image_tensors, decoder_input_ids, decoder_labels)[0]
self.log_dict({"train_loss": loss}, sync_dist=True)
if not self.pytorch_lightning_version_is_1:
self.log('loss', loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx, dataset_idx=0):
def on_validation_epoch_start(self) -> None:
    """Reset the per-dataloader score buffers at the start of each validation epoch.

    Creates one independent empty list per validation dataloader; ``validation_step``
    appends its scores into ``self.validation_step_outputs[dataloader_idx]`` and
    ``on_validation_epoch_end`` aggregates them.
    """
    super().on_validation_epoch_start()
    buffers = []
    for _ in range(self.num_of_loaders):
        buffers.append([])  # distinct list objects — do NOT use [[]] * n (shared refs)
    self.validation_step_outputs = buffers
def validation_step(self, batch, batch_idx, dataloader_idx=0):
image_tensors, decoder_input_ids, prompt_end_idxs, answers = batch
decoder_prompts = pad_sequence(
[input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
......@@ -84,17 +93,16 @@ class DonutModelPLModule(pl.LightningModule):
self.print(f" Answer: {answer}")
self.print(f" Normed ED: {scores[0]}")
self.validation_step_outputs[dataloader_idx].append(scores)
return scores
def validation_epoch_end(self, validation_step_outputs):
num_of_loaders = len(self.config.dataset_name_or_paths)
if num_of_loaders == 1:
validation_step_outputs = [validation_step_outputs]
assert len(validation_step_outputs) == num_of_loaders
cnt = [0] * num_of_loaders
total_metric = [0] * num_of_loaders
val_metric = [0] * num_of_loaders
for i, results in enumerate(validation_step_outputs):
def on_validation_epoch_end(self):
assert len(self.validation_step_outputs) == self.num_of_loaders
cnt = [0] * self.num_of_loaders
total_metric = [0] * self.num_of_loaders
val_metric = [0] * self.num_of_loaders
for i, results in enumerate(self.validation_step_outputs):
for scores in results:
cnt[i] += len(scores)
total_metric[i] += np.sum(scores)
......@@ -136,13 +144,6 @@ class DonutModelPLModule(pl.LightningModule):
return LambdaLR(optimizer, lr_lambda)
def get_progress_bar_dict(self):
    """Customize the metrics shown in the progress bar.

    Drops the noisy ``v_num`` entry and tags the bar with the experiment
    name/version read from the module config.
    """
    metrics = super().get_progress_bar_dict()
    metrics.pop("v_num", None)
    metrics.update(
        exp_name=f"{self.config.get('exp_name', '')}",
        exp_version=f"{self.config.get('exp_version', '')}",
    )
    return metrics
@rank_zero_only
def on_save_checkpoint(self, checkpoint):
save_path = Path(self.config.result_path) / self.config.exp_name / self.config.exp_version
......
......@@ -51,8 +51,34 @@ def save_config_file(config, path):
print(f"Config is saved at {save_path}")
class ProgressBar(pl.callbacks.TQDMProgressBar):
    """TQDM progress bar that tags every update with the experiment name/version.

    Replaces the pl<2 ``get_progress_bar_dict`` hook on the LightningModule with
    the callback-side ``get_metrics`` hook.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): assigning a bool here shadows any inherited ``enable``
        # method on the base progress bar — looks intentional (common PL
        # pattern paired with ``disable`` below); confirm against the installed
        # pytorch-lightning version.
        self.enable = True
        self.config = config

    def disable(self):
        """Turn the bar off (e.g. on non-zero ranks)."""
        self.enable = False

    def get_metrics(self, trainer, model):
        """Return the metrics dict shown next to the bar, minus ``v_num``."""
        metrics = super().get_metrics(trainer, model)
        metrics.pop("v_num", None)  # drop the noisy logger version number
        metrics["exp_name"] = f"{self.config.get('exp_name', '')}"
        metrics["exp_version"] = f"{self.config.get('exp_version', '')}"
        return metrics
def set_seed(seed):
    """Seed all RNGs, dispatching on the installed pytorch-lightning major version.

    pl < 2 exposes ``seed_everything`` under ``pl.utilities.seed``; in pl >= 2
    the implementation lives in ``lightning_fabric.utilities.seed``.

    Args:
        seed: integer seed forwarded to ``seed_everything`` (workers seeded too).
    """
    # Parse the full major component. The original ``int(pl.__version__[0])``
    # reads only the first character and would misparse any double-digit major
    # version (e.g. "10.2.0" -> 1).
    major_version = int(pl.__version__.split(".")[0])
    if major_version < 2:
        pl.utilities.seed.seed_everything(seed, workers=True)
    else:
        # Imported lazily so pl 1.x environments without lightning_fabric still work.
        import lightning_fabric

        lightning_fabric.utilities.seed.seed_everything(seed, workers=True)
def train(config):
pl.utilities.seed.seed_everything(config.get("seed", 42), workers=True)
set_seed(config.get("seed", 42))
model_module = DonutModelPLModule(config)
data_module = DonutDataPLModule(config)
......@@ -111,11 +137,12 @@ def train(config):
mode="min",
)
bar = ProgressBar(config)
custom_ckpt = CustomCheckpointIO()
trainer = pl.Trainer(
resume_from_checkpoint=config.get("resume_from_checkpoint_path", None),
num_nodes=config.get("num_nodes", 1),
gpus=torch.cuda.device_count(),
devices=torch.cuda.device_count(),
strategy="ddp",
accelerator="gpu",
plugins=custom_ckpt,
......@@ -127,10 +154,10 @@ def train(config):
precision=16,
num_sanity_val_steps=0,
logger=logger,
callbacks=[lr_callback, checkpoint_callback],
callbacks=[lr_callback, checkpoint_callback, bar],
)
trainer.fit(model_module, data_module)
trainer.fit(model_module, data_module, ckpt_path=config.get("resume_from_checkpoint_path", None))
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment