Unverified commit 88e84186 authored by Stas Bekman, committed by GitHub

[style] consistent nn. and nn.functional: part 4 `examples` (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
parent 372ab9cd
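
The change is mechanical throughout: each file imports the `nn` namespace once with `from torch import nn` and then refers to `nn.Module`, `nn.functional.*`, `nn.utils.*`, and so on, instead of `torch.nn.*` or the `F` alias. A minimal sketch of the convention the diff below applies (the class, tensor shapes, and training snippet are illustrative only, not taken from the repository):

```python
import torch
from torch import nn  # preferred over `import torch.nn as nn` / `import torch.nn.functional as F`


class TinyHead(nn.Module):
    """Toy classification head written in the `nn.` / `nn.functional.` style."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = nn.functional.relu(self.dense(hidden_states))           # was: F.relu(...)
        return nn.functional.log_softmax(self.out_proj(x), dim=-1)  # was: F.log_softmax(...)


model = TinyHead(hidden_size=8, num_labels=3)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # was: torch.nn.DataParallel(model)

log_probs = model(torch.randn(2, 8))
loss = nn.NLLLoss()(log_probs, torch.tensor([0, 2]))
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # was: torch.nn.utils.clip_grad_norm_(...)
```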
@@ -17,7 +17,7 @@
 import logging
 import torch
-import torch.nn as nn
+from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -270,6 +270,7 @@ class AlbertForSequenceClassificationWithPabee(AlbertPreTrainedModel):
 from transformers import AlbertTokenizer
 from pabee import AlbertForSequenceClassificationWithPabee
+from torch import nn
 import torch
 tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
...
@@ -294,6 +294,7 @@ class BertForSequenceClassificationWithPabee(BertPreTrainedModel):
 from transformers import BertTokenizer, BertForSequenceClassification
 from pabee import BertForSequenceClassificationWithPabee
+from torch import nn
 import torch
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
...
@@ -25,6 +25,7 @@ import random
 import numpy as np
 import torch
+from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -117,11 +118,11 @@ def train(args, train_dataset, model, tokenizer):
 # multi-gpu training (should be after apex fp16 initialization)
 if args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Distributed training (should be after apex fp16 initialization)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model,
         device_ids=[args.local_rank],
         output_device=args.local_rank,
@@ -203,9 +204,9 @@ def train(args, train_dataset, model, tokenizer):
 tr_loss += loss.item()
 if (step + 1) % args.gradient_accumulation_steps == 0:
     if args.fp16:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
     else:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
     optimizer.step()
     scheduler.step()  # Update learning rate schedule
@@ -291,8 +292,8 @@ def evaluate(args, model, tokenizer, prefix="", patience=0):
 eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 # multi-gpu eval
-if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
-    model = torch.nn.DataParallel(model)
+if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
+    model = nn.DataParallel(model)
 # Eval!
 logger.info("***** Running evaluation {} *****".format(prefix))
...
@@ -26,6 +26,7 @@ from datetime import datetime
 import numpy as np
 import torch
+from torch import nn
 from torch.utils.data import DataLoader, SequentialSampler, Subset
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
@@ -415,11 +416,11 @@ def main():
 # Distributed and parallel training
 model.to(args.device)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
     )
 elif args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Print/save training arguments
 os.makedirs(args.output_dir, exist_ok=True)
...
@@ -10,6 +10,7 @@ from datetime import datetime
 import numpy as np
 import torch
+from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, TensorDataset
 from tqdm import tqdm
@@ -352,11 +353,11 @@ def main():
 # Distributed and parallel training
 model.to(args.device)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
     )
 elif args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Print/save training arguments
 os.makedirs(args.output_dir, exist_ok=True)
...
@@ -9,6 +9,7 @@ import time
 import numpy as np
 import torch
+from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -135,11 +136,11 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
 # multi-gpu training (should be after apex fp16 initialization)
 if args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Distributed training (should be after apex fp16 initialization)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
     )
@@ -190,9 +191,9 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
 tr_loss += loss.item()
 if (step + 1) % args.gradient_accumulation_steps == 0:
     if args.fp16:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
     else:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
     optimizer.step()
     scheduler.step()  # Update learning rate schedule
@@ -255,7 +256,7 @@ def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=Fa
 # multi-gpu eval
 if args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Eval!
 logger.info("***** Running evaluation {} *****".format(prefix))
...
 from __future__ import absolute_import, division, print_function, unicode_literals
-import torch.nn as nn
+from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from transformers import RobertaConfig
...
@@ -21,8 +21,7 @@ import time
 import psutil
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.optim import AdamW
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -412,8 +411,8 @@ class Distiller:
 loss_ce = (
     self.ce_loss_fct(
-        F.log_softmax(s_logits_slct / self.temperature, dim=-1),
-        F.softmax(t_logits_slct / self.temperature, dim=-1),
+        nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
+        nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
     )
     * (self.temperature) ** 2
 )
@@ -492,9 +491,9 @@ class Distiller:
 self.iter()
 if self.n_iter % self.params.gradient_accumulation_steps == 0:
     if self.fp16:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
+        nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
     else:
-        torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
+        nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
     self.optimizer.step()
     self.optimizer.zero_grad()
     self.scheduler.step()
...
@@ -24,8 +24,7 @@ import timeit
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -138,11 +137,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 # multi-gpu training (should be after apex fp16 initialization)
 if args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Distributed training (should be after apex fp16 initialization)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
     )
@@ -232,15 +231,15 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 loss_fct = nn.KLDivLoss(reduction="batchmean")
 loss_start = (
     loss_fct(
-        F.log_softmax(start_logits_stu / args.temperature, dim=-1),
-        F.softmax(start_logits_tea / args.temperature, dim=-1),
+        nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
+        nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
     )
     * (args.temperature ** 2)
 )
 loss_end = (
     loss_fct(
-        F.log_softmax(end_logits_stu / args.temperature, dim=-1),
-        F.softmax(end_logits_tea / args.temperature, dim=-1),
+        nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
+        nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
     )
     * (args.temperature ** 2)
 )
@@ -262,9 +261,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 tr_loss += loss.item()
 if (step + 1) % args.gradient_accumulation_steps == 0:
     if args.fp16:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
     else:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
     optimizer.step()
     scheduler.step()  # Update learning rate schedule
@@ -326,8 +325,8 @@ def evaluate(args, model, tokenizer, prefix=""):
 eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 # multi-gpu evaluate
-if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
-    model = torch.nn.DataParallel(model)
+if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
+    model = nn.DataParallel(model)
 # Eval!
 logger.info("***** Running evaluation {} *****".format(prefix))
...
@@ -11,6 +11,7 @@ import torch
 import torch.utils.checkpoint as checkpoint
 from elasticsearch import Elasticsearch  # noqa: F401
 from elasticsearch.helpers import bulk, streaming_bulk  # noqa: F401
+from torch import nn
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
 from tqdm import tqdm
@@ -116,14 +117,14 @@ class ELI5DatasetQARetriver(Dataset):
         return self.make_example(idx % self.data.num_rows)
-class RetrievalQAEmbedder(torch.nn.Module):
+class RetrievalQAEmbedder(nn.Module):
     def __init__(self, sent_encoder, dim):
         super(RetrievalQAEmbedder, self).__init__()
         self.sent_encoder = sent_encoder
         self.output_dim = 128
-        self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
-        self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
-        self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
+        self.project_q = nn.Linear(dim, self.output_dim, bias=False)
+        self.project_a = nn.Linear(dim, self.output_dim, bias=False)
+        self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
     def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
         # reproduces BERT forward pass with checkpointing
...
@@ -25,7 +25,6 @@ from typing import Dict, List, Tuple
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import functional as F
 from torch.nn.modules.batchnorm import BatchNorm2d
 from torchvision.ops import RoIPool
 from torchvision.ops.boxes import batched_nms, nms
@@ -85,7 +84,7 @@ def pad_list_tensors(
 too_small = True
 tensor_i = tensor_i.unsqueeze(-1)
 assert isinstance(tensor_i, torch.Tensor)
-tensor_i = F.pad(
+tensor_i = nn.functional.pad(
     input=tensor_i,
     pad=(0, 0, 0, max_detections - preds_per_image[i]),
     mode="constant",
@@ -701,7 +700,7 @@ class RPNOutputs(object):
 # Main Classes
-class Conv2d(torch.nn.Conv2d):
+class Conv2d(nn.Conv2d):
     def __init__(self, *args, **kwargs):
         norm = kwargs.pop("norm", None)
         activation = kwargs.pop("activation", None)
@@ -712,9 +711,9 @@ class Conv2d(torch.nn.Conv2d):
     def forward(self, x):
         if x.numel() == 0 and self.training:
-            assert not isinstance(self.norm, torch.nn.SyncBatchNorm)
+            assert not isinstance(self.norm, nn.SyncBatchNorm)
         if x.numel() == 0:
-            assert not isinstance(self.norm, torch.nn.GroupNorm)
+            assert not isinstance(self.norm, nn.GroupNorm)
             output_shape = [
                 (i + 2 * p - (di * (k - 1) + 1)) // s + 1
                 for i, p, di, k, s in zip(
@@ -752,7 +751,7 @@ class LastLevelMaxPool(nn.Module):
         self.in_feature = "p5"
     def forward(self, x):
-        return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
+        return [nn.functional.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
 class LastLevelP6P7(nn.Module):
@@ -769,7 +768,7 @@ class LastLevelP6P7(nn.Module):
     def forward(self, c5):
         p6 = self.p6(c5)
-        p7 = self.p7(F.relu(p6))
+        p7 = self.p7(nn.functional.relu(p6))
         return [p6, p7]
@@ -790,11 +789,11 @@ class BasicStem(nn.Module):
     def forward(self, x):
         x = self.conv1(x)
-        x = F.relu_(x)
+        x = nn.functional.relu_(x)
         if self.caffe_maxpool:
-            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
+            x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
         else:
-            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
+            x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1)
         return x
     @property
@@ -881,10 +880,10 @@ class BottleneckBlock(ResNetBlockBase):
    def forward(self, x):
         out = self.conv1(x)
-        out = F.relu_(out)
+        out = nn.functional.relu_(out)
         out = self.conv2(out)
-        out = F.relu_(out)
+        out = nn.functional.relu_(out)
         out = self.conv3(out)
@@ -894,7 +893,7 @@ class BottleneckBlock(ResNetBlockBase):
             shortcut = x
         out += shortcut
-        out = F.relu_(out)
+        out = nn.functional.relu_(out)
         return out
@@ -1159,7 +1158,7 @@ class ROIOutputs(object):
         return boxes.view(num_pred, K * B).split(preds_per_image, dim=0)
     def _predict_objs(self, obj_logits, preds_per_image):
-        probs = F.softmax(obj_logits, dim=-1)
+        probs = nn.functional.softmax(obj_logits, dim=-1)
         probs = probs.split(preds_per_image, dim=0)
         return probs
@@ -1490,7 +1489,7 @@ class RPNHead(nn.Module):
         pred_objectness_logits = []
         pred_anchor_deltas = []
         for x in features:
-            t = F.relu(self.conv(x))
+            t = nn.functional.relu(self.conv(x))
             pred_objectness_logits.append(self.objectness_logits(t))
             pred_anchor_deltas.append(self.anchor_deltas(t))
         return pred_objectness_logits, pred_anchor_deltas
@@ -1650,7 +1649,7 @@ class FastRCNNOutputLayers(nn.Module):
             cls_emb = self.cls_embedding(max_class)  # [b] --> [b, 256]
             roi_features = torch.cat([roi_features, cls_emb], -1)  # [b, 2048] + [b, 256] --> [b, 2304]
             roi_features = self.fc_attr(roi_features)
-            roi_features = F.relu(roi_features)
+            roi_features = nn.functional.relu(roi_features)
             attr_scores = self.attr_score(roi_features)
             return scores, attr_scores, proposal_deltas
         else:
...
@@ -20,8 +20,8 @@ from typing import Tuple
 import numpy as np
 import torch
-import torch.nn.functional as F
 from PIL import Image
+from torch import nn
 from utils import img_tensorize
@@ -63,7 +63,9 @@ class ResizeShortestEdge:
         img = np.asarray(pil_image)
     else:
         img = img.permute(2, 0, 1).unsqueeze(0)  # 3, 0, 1)  # hw(c) -> nchw
-        img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
+        img = nn.functional.interpolate(
+            img, (newh, neww), mode=self.interp_method, align_corners=False
+        ).squeeze(0)
     img_augs.append(img)
 return img_augs
@@ -85,7 +87,7 @@ class Preprocess:
 max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
 image_sizes = [im.shape[-2:] for im in images]
 images = [
-    F.pad(
+    nn.functional.pad(
         im,
         [0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
         value=self.pad_value,
...
@@ -25,8 +25,8 @@ import random
 import numpy as np
 import torch
-import torch.nn as nn
 from sklearn.metrics import f1_score
+from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -107,11 +107,11 @@ def train(args, train_dataset, model, tokenizer, criterion):
 # multi-gpu training (should be after apex fp16 initialization)
 if args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Distributed training (should be after apex fp16 initialization)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
     )
@@ -166,9 +166,9 @@ def train(args, train_dataset, model, tokenizer, criterion):
 tr_loss += loss.item()
 if (step + 1) % args.gradient_accumulation_steps == 0:
     if args.fp16:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
     else:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
     optimizer.step()
     scheduler.step()  # Update learning rate schedule
@@ -248,8 +248,8 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
 )
 # multi-gpu eval
-if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
-    model = torch.nn.DataParallel(model)
+if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
+    model = nn.DataParallel(model)
 # Eval!
 logger.info("***** Running evaluation {} *****".format(prefix))
...
@@ -19,10 +19,10 @@ import os
 from collections import Counter
 import torch
-import torch.nn as nn
 import torchvision
 import torchvision.transforms as transforms
 from PIL import Image
+from torch import nn
 from torch.utils.data import Dataset
...
@@ -75,7 +75,7 @@
 "quantized_model = torch.quantization.quantize_dynamic(\n",
 " model=model,\n",
 " qconfig_spec = {\n",
-" torch.nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
+" nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
 " },\n",
 " dtype=torch.qint8,\n",
 " )\n",
...
@@ -23,7 +23,6 @@ import math
 import torch
 from torch import nn
-from torch.nn import functional as F
 from torch.nn import init
 from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
@@ -104,4 +103,4 @@ class MaskedLinear(nn.Linear):
         # Mask weights with computed mask
         weight_thresholded = mask * self.weight
         # Compute output (linear layer) with masked weights
-        return F.linear(input, weight_thresholded, self.bias)
+        return nn.functional.linear(input, weight_thresholded, self.bias)
@@ -24,8 +24,7 @@ import random
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -168,11 +167,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 # multi-gpu training (should be after apex fp16 initialization)
 if args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Distributed training (should be after apex fp16 initialization)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model,
         device_ids=[args.local_rank],
         output_device=args.local_rank,
@@ -287,9 +286,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 )
 loss_logits = (
-    F.kl_div(
-        input=F.log_softmax(logits_stu / args.temperature, dim=-1),
-        target=F.softmax(logits_tea / args.temperature, dim=-1),
+    nn.functional.kl_div(
+        input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
+        target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
         reduction="batchmean",
     )
     * (args.temperature ** 2)
@@ -320,9 +319,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
     and (step + 1) == len(epoch_iterator)
 ):
     if args.fp16:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
     else:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
     tb_writer.add_scalar("threshold", threshold, global_step)
@@ -436,8 +435,8 @@ def evaluate(args, model, tokenizer, prefix=""):
 eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 # multi-gpu eval
-if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
-    model = torch.nn.DataParallel(model)
+if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
+    model = nn.DataParallel(model)
 # Eval!
 logger.info("***** Running evaluation {} *****".format(prefix))
...
@@ -25,8 +25,7 @@ import timeit
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
@@ -176,11 +175,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 # multi-gpu training (should be after apex fp16 initialization)
 if args.n_gpu > 1:
-    model = torch.nn.DataParallel(model)
+    model = nn.DataParallel(model)
 # Distributed training (should be after apex fp16 initialization)
 if args.local_rank != -1:
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = nn.parallel.DistributedDataParallel(
         model,
         device_ids=[args.local_rank],
         output_device=args.local_rank,
@@ -308,17 +307,17 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 )
 loss_start = (
-    F.kl_div(
-        input=F.log_softmax(start_logits_stu / args.temperature, dim=-1),
-        target=F.softmax(start_logits_tea / args.temperature, dim=-1),
+    nn.functional.kl_div(
+        input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
+        target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
         reduction="batchmean",
     )
     * (args.temperature ** 2)
 )
 loss_end = (
-    F.kl_div(
-        input=F.log_softmax(end_logits_stu / args.temperature, dim=-1),
-        target=F.softmax(end_logits_tea / args.temperature, dim=-1),
+    nn.functional.kl_div(
+        input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
+        target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
         reduction="batchmean",
     )
     * (args.temperature ** 2)
@@ -346,9 +345,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 tr_loss += loss.item()
 if (step + 1) % args.gradient_accumulation_steps == 0:
     if args.fp16:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
     else:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
     tb_writer.add_scalar("threshold", threshold, global_step)
@@ -454,8 +453,8 @@ def evaluate(args, model, tokenizer, prefix=""):
 eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 # multi-gpu eval
-if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
-    model = torch.nn.DataParallel(model)
+if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
+    model = nn.DataParallel(model)
 # Eval!
 logger.info("***** Running evaluation {} *****".format(prefix))
...
-import torch
+from torch import nn
-class ClassificationHead(torch.nn.Module):
+class ClassificationHead(nn.Module):
     """Classification Head for transformer encoders"""
     def __init__(self, class_size, embed_size):
         super().__init__()
         self.class_size = class_size
         self.embed_size = embed_size
-        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
-        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
-        self.mlp = torch.nn.Linear(embed_size, class_size)
+        # self.mlp1 = nn.Linear(embed_size, embed_size)
+        # self.mlp2 = (nn.Linear(embed_size, class_size))
+        self.mlp = nn.Linear(embed_size, class_size)
     def forward(self, hidden_state):
-        # hidden_state = F.relu(self.mlp1(hidden_state))
+        # hidden_state = nn.functional.relu(self.mlp1(hidden_state))
         # hidden_state = self.mlp2(hidden_state)
         logits = self.mlp(hidden_state)
         return logits
@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
-import torch.nn.functional as F
+from torch import nn
 from tqdm import trange
 from pplm_classification_head import ClassificationHead
@@ -160,7 +160,7 @@ def perturb_past(
 new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
 # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
 logits = all_logits[:, -1, :]
-probs = F.softmax(logits, dim=-1)
+probs = nn.functional.softmax(logits, dim=-1)
 loss = 0.0
 loss_list = []
@@ -173,7 +173,7 @@ def perturb_past(
     print(" pplm_bow_loss:", loss.data.cpu().numpy())
 if loss_type == 2 or loss_type == 3:
-    ce_loss = torch.nn.CrossEntropyLoss()
+    ce_loss = nn.CrossEntropyLoss()
     # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
     curr_unpert_past = unpert_past
     curr_probs = torch.unsqueeze(probs, dim=1)
@@ -195,7 +195,7 @@ def perturb_past(
 kl_loss = 0.0
 if kl_scale > 0.0:
-    unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
+    unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
     unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
     correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
     corrected_probs = probs + correction.detach()
@@ -527,10 +527,10 @@ def generate_text_pplm(
 else:
     pert_logits[0, token_idx] /= repetition_penalty
-pert_probs = F.softmax(pert_logits, dim=-1)
+pert_probs = nn.functional.softmax(pert_logits, dim=-1)
 if classifier is not None:
-    ce_loss = torch.nn.CrossEntropyLoss()
+    ce_loss = nn.CrossEntropyLoss()
     prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
     label = torch.tensor([class_label], device=device, dtype=torch.long)
     unpert_discrim_loss = ce_loss(prediction, label)
@@ -541,7 +541,7 @@ def generate_text_pplm(
 # Fuse the modified model and original model
 if perturb:
-    unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
+    unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
     pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))  # + SMALL_CONST
     pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST
@@ -552,7 +552,7 @@ def generate_text_pplm(
 else:
     pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
-    pert_probs = F.softmax(pert_logits, dim=-1)
+    pert_probs = nn.functional.softmax(pert_logits, dim=-1)
 # sample or greedy
 if sample:
...