"third_party/vscode:/vscode.git/clone" did not exist on "03badcd3e91c7f54ec42f0b13c675070ef40c016"
Unverified commit 88e84186 authored by Stas Bekman, committed by GitHub

[style] consistent nn. and nn.functional: part 4 `examples` (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
parent 372ab9cd
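
For reviewers skimming the diff below, here is a minimal sketch of the convention this series applies (not taken from the commit; the tensor shapes and values are made up for illustration): import the `nn` namespace once with `from torch import nn`, then reference layers as `nn.*` and functional ops as `nn.functional.*`, instead of aliasing `torch.nn.functional as F` or spelling out `torch.nn.*` at call sites.

```python
import torch
from torch import nn  # replaces both `import torch.nn as nn` and `import torch.nn.functional as F`

# made-up logits for a batch of 4 examples over 3 classes, with integer class labels
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 0])

log_probs = nn.functional.log_softmax(logits, dim=-1)  # was: F.log_softmax(logits, dim=-1)
loss = nn.functional.nll_loss(log_probs, targets)      # was: F.nll_loss(log_probs, targets)

ce_loss_fct = nn.CrossEntropyLoss()                    # was: torch.nn.CrossEntropyLoss()
same_loss = ce_loss_fct(logits, targets)               # numerically equal to `loss` above
```
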
@@ -23,10 +23,10 @@ import time
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
 from nltk.tokenize.treebank import TreebankWordDetokenizer
+from torch import nn
 from torchtext import data as torchtext_data
 from torchtext import datasets
 from tqdm import tqdm, trange
@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100
-class Discriminator(torch.nn.Module):
+class Discriminator(nn.Module):
     """Transformer encoder followed by a Classification Head"""
     def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
         avg_hidden = self.avg_representation(x.to(self.device))
         logits = self.classifier_head(avg_hidden)
-        probs = F.log_softmax(logits, dim=-1)
+        probs = nn.functional.log_softmax(logits, dim=-1)
         return probs
@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
         optimizer.zero_grad()
         output_t = discriminator(input_t)
-        loss = F.nll_loss(output_t, target_t)
+        loss = nn.functional.nll_loss(output_t, target_t)
         loss.backward(retain_graph=True)
         optimizer.step()
@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
             input_t, target_t = input_t.to(device), target_t.to(device)
             output_t = discriminator(input_t)
             # sum up batch loss
-            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
+            test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
             # get the index of the max log-probability
             pred_t = output_t.argmax(dim=1, keepdim=True)
             correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
...
@@ -8,6 +8,7 @@ from pathlib import Path
 import pytest
 import pytorch_lightning as pl
 import torch
+from torch import nn
 import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
@@ -183,7 +184,7 @@ class TestSummarizationDistiller(TestCasePlus):
         logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
-        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+        lprobs = nn.functional.log_softmax(logits, dim=-1)
         smoothed_loss, nll_loss = label_smoothed_nll_loss(
             lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
         )
...
@@ -10,7 +10,6 @@ from typing import List
 import pytorch_lightning as pl
 import torch
 from torch import nn
-from torch.nn import functional as F
 from finetune import SummarizationModule, TranslationModule
 from finetune import main as ft_main
@@ -123,8 +122,8 @@ class SummarizationDistiller(SummarizationModule):
         assert t_logits_slct.size() == s_logits_slct.size()
         loss_ce = (
             self.ce_loss_fct(
-                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
-                F.softmax(t_logits_slct / self.temperature, dim=-1),
+                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
+                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
             )
             * (self.temperature) ** 2
         )
@@ -160,10 +159,10 @@ class SummarizationDistiller(SummarizationModule):
         assert lm_logits.shape[-1] == self.model.config.vocab_size
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
             student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
         else:
-            lprobs = F.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             student_lm_loss, _ = label_smoothed_nll_loss(
                 lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
@@ -230,9 +229,9 @@ class SummarizationDistiller(SummarizationModule):
         teacher_states = torch.stack([hidden_states_T[j] for j in matches])
         assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
         if normalize_hidden:
-            student_states = F.layer_norm(student_states, student_states.shape[1:])
-            teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
-        mse = F.mse_loss(student_states, teacher_states, reduction="none")
+            student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
+            teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
+        mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
         masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
         return masked_mse
...
@@ -13,6 +13,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from torch import nn
 from torch.utils.data import DataLoader
 from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
@@ -151,12 +152,12 @@ class SummarizationModule(BaseTransformer):
         lm_logits = outputs["logits"]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
             assert lm_logits.shape[-1] == self.vocab_size
             loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
-            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
+            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
                 lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
...
@@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
 import datasets
 import numpy as np
 import torch
-import torch.nn as nn
 from packaging import version
+from torch import nn
 import librosa
 from lang_trans import arabic
...
@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union
 import torch
-import torch.nn as nn
 from datasets import DatasetDict, load_dataset
 from packaging import version
+from torch import nn
 import librosa
 from transformers import (
...