Unverified commit 88e84186, authored by Stas Bekman and committed by GitHub

[style] consistent nn. and nn.functional: part 4 `examples` (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
parent 372ab9cd
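For reference, a minimal sketch (not part of this commit; tensor shapes are made up) of why the renames below are purely stylistic: the `F` / `torch.nn.*` spellings and the `nn.` / `nn.functional.` spellings resolve to the same PyTorch objects.

```python
# Minimal sketch (not part of this commit): the old and new spellings point at
# the identical callables, so behavior is unchanged. Shapes below are made up.
import torch
import torch.nn.functional as F
from torch import nn

# Same module attribute, same function object ...
assert nn.functional.log_softmax is F.log_softmax
assert nn.CrossEntropyLoss is torch.nn.CrossEntropyLoss

# ... and therefore identical results for old- and new-style calls.
logits = torch.randn(4, 10)
assert torch.equal(
    nn.functional.log_softmax(logits, dim=-1),
    F.log_softmax(logits, dim=-1),
)
```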
@@ -23,10 +23,10 @@ import time
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
 from nltk.tokenize.treebank import TreebankWordDetokenizer
+from torch import nn
 from torchtext import data as torchtext_data
 from torchtext import datasets
 from tqdm import tqdm, trange
@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100
-class Discriminator(torch.nn.Module):
+class Discriminator(nn.Module):
     """Transformer encoder followed by a Classification Head"""
     def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
             avg_hidden = self.avg_representation(x.to(self.device))
         logits = self.classifier_head(avg_hidden)
-        probs = F.log_softmax(logits, dim=-1)
+        probs = nn.functional.log_softmax(logits, dim=-1)
         return probs
@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
         optimizer.zero_grad()
         output_t = discriminator(input_t)
-        loss = F.nll_loss(output_t, target_t)
+        loss = nn.functional.nll_loss(output_t, target_t)
         loss.backward(retain_graph=True)
         optimizer.step()
@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
             input_t, target_t = input_t.to(device), target_t.to(device)
             output_t = discriminator(input_t)
             # sum up batch loss
-            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
+            test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
             # get the index of the max log-probability
             pred_t = output_t.argmax(dim=1, keepdim=True)
             correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
@@ -8,6 +8,7 @@ from pathlib import Path
 import pytest
 import pytorch_lightning as pl
 import torch
+from torch import nn
 import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
@@ -183,7 +184,7 @@ class TestSummarizationDistiller(TestCasePlus):
         logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
-        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+        lprobs = nn.functional.log_softmax(logits, dim=-1)
         smoothed_loss, nll_loss = label_smoothed_nll_loss(
             lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
         )
@@ -10,7 +10,6 @@ from typing import List
 import pytorch_lightning as pl
 import torch
 from torch import nn
-from torch.nn import functional as F
 from finetune import SummarizationModule, TranslationModule
 from finetune import main as ft_main
@@ -123,8 +122,8 @@ class SummarizationDistiller(SummarizationModule):
         assert t_logits_slct.size() == s_logits_slct.size()
         loss_ce = (
             self.ce_loss_fct(
-                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
-                F.softmax(t_logits_slct / self.temperature, dim=-1),
+                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
+                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
             )
             * (self.temperature) ** 2
         )
@@ -160,10 +159,10 @@ class SummarizationDistiller(SummarizationModule):
         assert lm_logits.shape[-1] == self.model.config.vocab_size
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
             student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
         else:
-            lprobs = F.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             student_lm_loss, _ = label_smoothed_nll_loss(
                 lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
@@ -230,9 +229,9 @@ class SummarizationDistiller(SummarizationModule):
         teacher_states = torch.stack([hidden_states_T[j] for j in matches])
         assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
         if normalize_hidden:
-            student_states = F.layer_norm(student_states, student_states.shape[1:])
-            teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
-        mse = F.mse_loss(student_states, teacher_states, reduction="none")
+            student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
+            teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
+        mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
         masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
         return masked_mse
@@ -13,6 +13,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from torch import nn
 from torch.utils.data import DataLoader
 from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
@@ -151,12 +152,12 @@ class SummarizationModule(BaseTransformer):
         lm_logits = outputs["logits"]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
             assert lm_logits.shape[-1] == self.vocab_size
             loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
-            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
                 lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
@@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
 import datasets
 import numpy as np
 import torch
-import torch.nn as nn
 from packaging import version
+from torch import nn
 import librosa
 from lang_trans import arabic
@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union
 import torch
-import torch.nn as nn
 from datasets import DatasetDict, load_dataset
 from packaging import version
+from torch import nn
 import librosa
 from transformers import (