Unverified commit 88e84186 authored by Stas Bekman, committed by GitHub

[style] consistent nn. and nn.functional: part 4 `examples` (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
parent 372ab9cd
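The hunks below apply a single convention across the `examples` scripts: import the namespace once with `from torch import nn`, then write layers as `nn.X` and functional ops as `nn.functional.X`, instead of mixing `torch.nn.X`, `import torch.nn as nn`, and the `F` alias for `torch.nn.functional`. A minimal sketch of the target style (the module below is illustrative, not a file from this commit):

import torch
from torch import nn


class TinyHead(nn.Module):
    """Illustrative module written in the style this commit standardizes on."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, num_labels)  # nn.Linear, not torch.nn.Linear

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.proj(hidden_states)
        # nn.functional.softmax replaces the old `F.softmax` spelling
        return nn.functional.softmax(logits, dim=-1)


probs = TinyHead(hidden_size=8, num_labels=2)(torch.randn(4, 8))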
@@ -17,7 +17,7 @@
import logging
import torch
import torch.nn as nn
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -270,6 +270,7 @@ class AlbertForSequenceClassificationWithPabee(AlbertPreTrainedModel):
from transformers import AlbertTokenizer
from pabee import AlbertForSequenceClassificationWithPabee
from torch import nn
import torch
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
......
@@ -294,6 +294,7 @@ class BertForSequenceClassificationWithPabee(BertPreTrainedModel):
from transformers import BertTokenizer, BertForSequenceClassification
from pabee import BertForSequenceClassificationWithPabee
from torch import nn
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
......
@@ -25,6 +25,7 @@ import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
@@ -117,11 +118,11 @@ def train(args, train_dataset, model, tokenizer):
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[args.local_rank],
output_device=args.local_rank,
@@ -203,9 +204,9 @@ def train(args, train_dataset, model, tokenizer):
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
@@ -291,8 +292,8 @@ def evaluate(args, model, tokenizer, prefix="", patience=0):
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
......
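The training-loop hunks above repeat two patterns under the new spelling: wrapping the model in `nn.DataParallel` / `nn.parallel.DistributedDataParallel`, and clipping gradients with `nn.utils.clip_grad_norm_`. A self-contained sketch of that step, with a toy model and random data standing in for the example scripts:

import torch
from torch import nn

model = nn.Linear(8, 2)                      # stand-in for the fine-tuned model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)           # was torch.nn.DataParallel(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
max_grad_norm = 1.0

inputs, labels = torch.randn(4, 8), torch.randint(0, 2, (4,))
loss = nn.functional.cross_entropy(model(inputs), labels)
loss.backward()

nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # was torch.nn.utils.clip_grad_norm_
optimizer.step()
optimizer.zero_grad()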
@@ -26,6 +26,7 @@ from datetime import datetime
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler, Subset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
@@ -415,11 +416,11 @@ def main():
# Distributed and parallel training
model.to(args.device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
)
elif args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Print/save training arguments
os.makedirs(args.output_dir, exist_ok=True)
......
@@ -10,6 +10,7 @@ from datetime import datetime
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from tqdm import tqdm
@@ -352,11 +353,11 @@ def main():
# Distributed and parallel training
model.to(args.device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
)
elif args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Print/save training arguments
os.makedirs(args.output_dir, exist_ok=True)
......
@@ -9,6 +9,7 @@ import time
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
@@ -135,11 +136,11 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
)
@@ -190,9 +191,9 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
@@ -255,7 +256,7 @@ def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=Fa
# multi-gpu eval
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
......
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import RobertaConfig
......
@@ -21,8 +21,7 @@ import time
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
@@ -412,8 +411,8 @@ class Distiller:
loss_ce = (
self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1),
nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
)
* (self.temperature) ** 2
)
@@ -492,9 +491,9 @@ class Distiller:
self.iter()
if self.n_iter % self.params.gradient_accumulation_steps == 0:
if self.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
......
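The Distiller hunk above rewrites the temperature-scaled soft-target loss with `nn.functional.log_softmax` / `nn.functional.softmax`. A self-contained restatement with random stand-in logits; the `nn.KLDivLoss(reduction="batchmean")` criterion mirrors the one used in the distillation hunks further down:

import torch
from torch import nn

temperature = 2.0
ce_loss_fct = nn.KLDivLoss(reduction="batchmean")

s_logits = torch.randn(4, 10)   # student logits (placeholder)
t_logits = torch.randn(4, 10)   # teacher logits (placeholder)

loss_ce = (
    ce_loss_fct(
        nn.functional.log_softmax(s_logits / temperature, dim=-1),
        nn.functional.softmax(t_logits / temperature, dim=-1),
    )
    * temperature ** 2
)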
@@ -24,8 +24,7 @@ import timeit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
@@ -138,11 +137,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
)
@@ -232,15 +231,15 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_start = (
loss_fct(
F.log_softmax(start_logits_stu / args.temperature, dim=-1),
F.softmax(start_logits_tea / args.temperature, dim=-1),
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
)
* (args.temperature ** 2)
)
loss_end = (
loss_fct(
F.log_softmax(end_logits_stu / args.temperature, dim=-1),
F.softmax(end_logits_tea / args.temperature, dim=-1),
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
)
* (args.temperature ** 2)
)
@@ -262,9 +261,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
@@ -326,8 +325,8 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
......
@@ -11,6 +11,7 @@ import torch
import torch.utils.checkpoint as checkpoint
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
@@ -116,14 +117,14 @@ class ELI5DatasetQARetriver(Dataset):
return self.make_example(idx % self.data.num_rows)
class RetrievalQAEmbedder(torch.nn.Module):
class RetrievalQAEmbedder(nn.Module):
def __init__(self, sent_encoder, dim):
super(RetrievalQAEmbedder, self).__init__()
self.sent_encoder = sent_encoder
self.output_dim = 128
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
self.project_q = nn.Linear(dim, self.output_dim, bias=False)
self.project_a = nn.Linear(dim, self.output_dim, bias=False)
self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
# reproduces BERT forward pass with checkpointing
......
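RetrievalQAEmbedder above keeps the same structure with the shorter spellings: `nn.Linear` projection heads plus an `nn.CrossEntropyLoss`. A hedged sketch of how such a dual-encoder head can be trained with an in-batch objective; the dot-product scoring and `arange` labels are an illustration only, not code shown in the hunk:

import torch
from torch import nn

dim, output_dim, batch = 16, 8, 4
project_q = nn.Linear(dim, output_dim, bias=False)
project_a = nn.Linear(dim, output_dim, bias=False)
ce_loss = nn.CrossEntropyLoss(reduction="mean")

q_reps = project_q(torch.randn(batch, dim))      # question embeddings (placeholders)
a_reps = project_a(torch.randn(batch, dim))      # answer embeddings (placeholders)
scores = q_reps @ a_reps.T                       # each question scored against every in-batch answer
loss = ce_loss(scores, torch.arange(batch))      # the matching answer acts as the "correct class"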
@@ -25,7 +25,6 @@ from typing import Dict, List, Tuple
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.ops import RoIPool
from torchvision.ops.boxes import batched_nms, nms
@@ -85,7 +84,7 @@ def pad_list_tensors(
too_small = True
tensor_i = tensor_i.unsqueeze(-1)
assert isinstance(tensor_i, torch.Tensor)
tensor_i = F.pad(
tensor_i = nn.functional.pad(
input=tensor_i,
pad=(0, 0, 0, max_detections - preds_per_image[i]),
mode="constant",
@@ -701,7 +700,7 @@ class RPNOutputs(object):
# Main Classes
class Conv2d(torch.nn.Conv2d):
class Conv2d(nn.Conv2d):
def __init__(self, *args, **kwargs):
norm = kwargs.pop("norm", None)
activation = kwargs.pop("activation", None)
@@ -712,9 +711,9 @@ class Conv2d(torch.nn.Conv2d):
def forward(self, x):
if x.numel() == 0 and self.training:
assert not isinstance(self.norm, torch.nn.SyncBatchNorm)
assert not isinstance(self.norm, nn.SyncBatchNorm)
if x.numel() == 0:
assert not isinstance(self.norm, torch.nn.GroupNorm)
assert not isinstance(self.norm, nn.GroupNorm)
output_shape = [
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
for i, p, di, k, s in zip(
@@ -752,7 +751,7 @@ class LastLevelMaxPool(nn.Module):
self.in_feature = "p5"
def forward(self, x):
return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
return [nn.functional.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
class LastLevelP6P7(nn.Module):
@@ -769,7 +768,7 @@ class LastLevelP6P7(nn.Module):
def forward(self, c5):
p6 = self.p6(c5)
p7 = self.p7(F.relu(p6))
p7 = self.p7(nn.functional.relu(p6))
return [p6, p7]
@@ -790,11 +789,11 @@ class BasicStem(nn.Module):
def forward(self, x):
x = self.conv1(x)
x = F.relu_(x)
x = nn.functional.relu_(x)
if self.caffe_maxpool:
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
else:
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1)
return x
@property
@@ -881,10 +880,10 @@ class BottleneckBlock(ResNetBlockBase):
def forward(self, x):
out = self.conv1(x)
out = F.relu_(out)
out = nn.functional.relu_(out)
out = self.conv2(out)
out = F.relu_(out)
out = nn.functional.relu_(out)
out = self.conv3(out)
@@ -894,7 +893,7 @@ class BottleneckBlock(ResNetBlockBase):
shortcut = x
out += shortcut
out = F.relu_(out)
out = nn.functional.relu_(out)
return out
@@ -1159,7 +1158,7 @@ class ROIOutputs(object):
return boxes.view(num_pred, K * B).split(preds_per_image, dim=0)
def _predict_objs(self, obj_logits, preds_per_image):
probs = F.softmax(obj_logits, dim=-1)
probs = nn.functional.softmax(obj_logits, dim=-1)
probs = probs.split(preds_per_image, dim=0)
return probs
@@ -1490,7 +1489,7 @@ class RPNHead(nn.Module):
pred_objectness_logits = []
pred_anchor_deltas = []
for x in features:
t = F.relu(self.conv(x))
t = nn.functional.relu(self.conv(x))
pred_objectness_logits.append(self.objectness_logits(t))
pred_anchor_deltas.append(self.anchor_deltas(t))
return pred_objectness_logits, pred_anchor_deltas
@@ -1650,7 +1649,7 @@ class FastRCNNOutputLayers(nn.Module):
cls_emb = self.cls_embedding(max_class) # [b] --> [b, 256]
roi_features = torch.cat([roi_features, cls_emb], -1) # [b, 2048] + [b, 256] --> [b, 2304]
roi_features = self.fc_attr(roi_features)
roi_features = F.relu(roi_features)
roi_features = nn.functional.relu(roi_features)
attr_scores = self.attr_score(roi_features)
return scores, attr_scores, proposal_deltas
else:
......
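The modeling hunks above swap `F.*` for `nn.functional.*` in the backbone forwards (relu_, max_pool2d, softmax, relu). A small runnable stem in that style, with placeholder layer sizes rather than the real backbone configuration:

import torch
from torch import nn


class TinyStem(nn.Module):
    """Illustrative stem in the style of BasicStem above."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu_(x)  # in-place relu, previously F.relu_
        return nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1)


features = TinyStem()(torch.randn(1, 3, 64, 64))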
@@ -20,8 +20,8 @@ from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from utils import img_tensorize
@@ -63,7 +63,9 @@ class ResizeShortestEdge:
img = np.asarray(pil_image)
else:
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
img = nn.functional.interpolate(
img, (newh, neww), mode=self.interp_method, align_corners=False
).squeeze(0)
img_augs.append(img)
return img_augs
@@ -85,7 +87,7 @@ class Preprocess:
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
image_sizes = [im.shape[-2:] for im in images]
images = [
F.pad(
nn.functional.pad(
im,
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
value=self.pad_value,
......
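The preprocessing hunks above resize with `nn.functional.interpolate` and pad with `nn.functional.pad`. A self-contained sketch with placeholder sizes; the bilinear mode is an assumption here, since the scripts take the interpolation method from configuration:

import torch
from torch import nn

img = torch.randn(3, 100, 150)  # C, H, W placeholder image tensor
img = nn.functional.interpolate(
    img.unsqueeze(0), (128, 192), mode="bilinear", align_corners=False
).squeeze(0)

# pad right/bottom up to a fixed canvas, as the Preprocess hunk does per batch
max_h, max_w = 256, 256
img = nn.functional.pad(img, [0, max_w - img.shape[-1], 0, max_h - img.shape[-2]], value=0)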
@@ -25,8 +25,8 @@ import random
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
@@ -107,11 +107,11 @@ def train(args, train_dataset, model, tokenizer, criterion):
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
)
@@ -166,9 +166,9 @@ def train(args, train_dataset, model, tokenizer, criterion):
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
@@ -248,8 +248,8 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
)
# multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
......
@@ -19,10 +19,10 @@ import os
from collections import Counter
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
......
@@ -75,7 +75,7 @@
"quantized_model = torch.quantization.quantize_dynamic(\n",
" model=model,\n",
" qconfig_spec = {\n",
" torch.nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
" nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
" },\n",
" dtype=torch.qint8,\n",
" )\n",
......
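The notebook cell above now names the layer type as `nn.Linear` inside the dynamic-quantization call. A hedged, self-contained version with a toy model standing in for the notebook's fine-tuned model:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))  # toy stand-in
quantized_model = torch.quantization.quantize_dynamic(
    model=model,
    qconfig_spec={nn.Linear: torch.quantization.default_dynamic_qconfig},
    dtype=torch.qint8,
)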
@@ -23,7 +23,6 @@ import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
@@ -104,4 +103,4 @@ class MaskedLinear(nn.Linear):
# Mask weights with computed mask
weight_thresholded = mask * self.weight
# Compute output (linear layer) with masked weights
return F.linear(input, weight_thresholded, self.bias)
return nn.functional.linear(input, weight_thresholded, self.bias)
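MaskedLinear's forward now calls `nn.functional.linear` on the masked weights. A self-contained sketch of that call; the magnitude-threshold mask here only illustrates `weight_thresholded` and is not one of the binarizers imported by the script:

import torch
from torch import nn

layer = nn.Linear(8, 4)
inputs = torch.randn(2, 8)

mask = (layer.weight.abs() > 0.1).float()   # illustrative binary pruning mask
weight_thresholded = mask * layer.weight    # mask weights as MaskedLinear does
output = nn.functional.linear(inputs, weight_thresholded, layer.bias)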
@@ -24,8 +24,7 @@ import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
@@ -168,11 +167,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[args.local_rank],
output_device=args.local_rank,
@@ -287,9 +286,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
)
loss_logits = (
F.kl_div(
input=F.log_softmax(logits_stu / args.temperature, dim=-1),
target=F.softmax(logits_tea / args.temperature, dim=-1),
nn.functional.kl_div(
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
reduction="batchmean",
)
* (args.temperature ** 2)
@@ -320,9 +319,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
and (step + 1) == len(epoch_iterator)
):
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
tb_writer.add_scalar("threshold", threshold, global_step)
@@ -436,8 +435,8 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
......
@@ -25,8 +25,7 @@ import timeit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
@@ -176,11 +175,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[args.local_rank],
output_device=args.local_rank,
@@ -308,17 +307,17 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
)
loss_start = (
F.kl_div(
input=F.log_softmax(start_logits_stu / args.temperature, dim=-1),
target=F.softmax(start_logits_tea / args.temperature, dim=-1),
nn.functional.kl_div(
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
reduction="batchmean",
)
* (args.temperature ** 2)
)
loss_end = (
F.kl_div(
input=F.log_softmax(end_logits_stu / args.temperature, dim=-1),
target=F.softmax(end_logits_tea / args.temperature, dim=-1),
nn.functional.kl_div(
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
reduction="batchmean",
)
* (args.temperature ** 2)
@@ -346,9 +345,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
tb_writer.add_scalar("threshold", threshold, global_step)
@@ -454,8 +453,8 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
......
import torch
from torch import nn
class ClassificationHead(torch.nn.Module):
class ClassificationHead(nn.Module):
"""Classification Head for transformer encoders"""
def __init__(self, class_size, embed_size):
super().__init__()
self.class_size = class_size
self.embed_size = embed_size
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
self.mlp = torch.nn.Linear(embed_size, class_size)
# self.mlp1 = nn.Linear(embed_size, embed_size)
# self.mlp2 = (nn.Linear(embed_size, class_size))
self.mlp = nn.Linear(embed_size, class_size)
def forward(self, hidden_state):
# hidden_state = F.relu(self.mlp1(hidden_state))
# hidden_state = nn.functional.relu(self.mlp1(hidden_state))
# hidden_state = self.mlp2(hidden_state)
logits = self.mlp(hidden_state)
return logits
@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import trange
from pplm_classification_head import ClassificationHead
@@ -160,7 +160,7 @@ def perturb_past(
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0
loss_list = []
@@ -173,7 +173,7 @@ def perturb_past(
print(" pplm_bow_loss:", loss.data.cpu().numpy())
if loss_type == 2 or loss_type == 3:
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss = nn.CrossEntropyLoss()
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
curr_unpert_past = unpert_past
curr_probs = torch.unsqueeze(probs, dim=1)
@@ -195,7 +195,7 @@ def perturb_past(
kl_loss = 0.0
if kl_scale > 0.0:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
corrected_probs = probs + correction.detach()
@@ -527,10 +527,10 @@ def generate_text_pplm(
else:
pert_logits[0, token_idx] /= repetition_penalty
pert_probs = F.softmax(pert_logits, dim=-1)
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss = nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
label = torch.tensor([class_label], device=device, dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label)
@@ -541,7 +541,7 @@ def generate_text_pplm(
# Fuse the modified model and original model
if perturb:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
@@ -552,7 +552,7 @@ def generate_text_pplm(
else:
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
pert_probs = F.softmax(pert_logits, dim=-1)
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
# sample or greedy
if sample:
......
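The PPLM hunks above end with probabilities computed via `nn.functional.softmax` before sampling or taking the argmax. An illustrative sampling step in that style, with random placeholder logits and the top-k filtering omitted:

import torch
from torch import nn

logits = torch.randn(1, 50257)                  # e.g. a GPT-2-sized vocabulary (placeholder)
probs = nn.functional.softmax(logits, dim=-1)   # was F.softmax in the old spelling
sample = True
if sample:
    next_token = torch.multinomial(probs, num_samples=1)   # sample from the distribution
else:
    next_token = torch.topk(probs, k=1, dim=-1)[1]         # greedy: take the argmax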