"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "d09b5ef4ef150adab31195761725eaba409f6343"
Unverified commit 76e5af4c, authored by Sam Shleifer and committed by GitHub

[pl_examples] revert deletion of optimizer_step (#5227)

parent c01480bb
@@ -116,6 +116,19 @@ class BaseTransformer(pl.LightningModule):
         self.opt = optimizer
         return [optimizer]

+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
+        if self.trainer.use_tpu:
+            xm.optimizer_step(optimizer)
+        else:
+            optimizer.step()
+        optimizer.zero_grad()
+        self.lr_scheduler.step()
+
+    def get_tqdm_dict(self):
+        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
+        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
+        return tqdm_dict
+
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
...
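For readers outside the Lightning examples, the restored optimizer_step hook does three things per training step: step the optimizer (through xm.optimizer_step on TPU, plain optimizer.step() otherwise), clear the gradients, and advance the learning-rate scheduler. Below is a minimal standalone sketch of that control flow, assuming torch_xla is only importable on TPU hosts; the helper name run_optimizer_step and the use_tpu flag are illustrative, not part of the repository.

```python
import torch

def run_optimizer_step(optimizer, lr_scheduler, use_tpu=False):
    # Mirrors the restored hook: TPU goes through the XLA helper,
    # everything else uses the plain PyTorch step.
    if use_tpu:
        import torch_xla.core.xla_model as xm  # only available on TPU hosts
        xm.optimizer_step(optimizer)
    else:
        optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()

# CPU/GPU usage example.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
model(torch.randn(8, 4)).sum().backward()
run_optimizer_step(optimizer, lr_scheduler)
```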
@@ -149,7 +149,7 @@ class SummarizationModule(BaseTransformer):
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
         generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
-        gen_time = time.time() - t0 / source_ids.shape[0]
+        gen_time = (time.time() - t0) / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
         loss_tensors = self._step(batch)
...
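The finetune.py change is an operator-precedence fix: in the old expression `time.time() - t0 / source_ids.shape[0]`, division binds tighter than subtraction, so only `t0` was divided and `gen_time` came out as a near-epoch-sized number instead of seconds per generated sample. A minimal reproduction, where the sleep and batch size are stand-ins for the generate() call and source_ids.shape[0]:

```python
import time

t0 = time.time()
time.sleep(0.2)   # stand-in for model.generate(...)
batch_size = 4    # stand-in for source_ids.shape[0]

wrong = time.time() - t0 / batch_size    # ~0.75 * t0: division applied to t0 only
right = (time.time() - t0) / batch_size  # elapsed seconds per generated sample
print(f"wrong={wrong:.3f}  right={right:.3f}")
```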
@@ -7,5 +7,6 @@ python distillation.py \
     --learning_rate=3e-4 \
     --do_train \
     --do_predict \
+    --fp16 \
     --val_check_interval 0.1 \
     $@
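The last hunk enables mixed-precision training for the distillation launch script via --fp16. The script hands this off to its training framework, but as a rough, hypothetical illustration of what fp16 training amounts to in plain PyTorch (torch.cuda.amp, requires a CUDA GPU):

```python
import torch

assert torch.cuda.is_available(), "float16 autocast in this sketch needs a CUDA GPU"

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()       # loss scaling guards against fp16 underflow

x, y = torch.randn(8, 16).cuda(), torch.randn(8, 4).cuda()
with torch.cuda.amp.autocast():            # forward pass runs in float16 where safe
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.step(optimizer)                     # unscales gradients, then optimizer.step()
scaler.update()
```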