Commit 1ed2ebf6 (unverified), authored by Stas Bekman, committed by GitHub
Browse files

[style] consistent nn. and nn.functional (#12124)

* consistent nn. and nn.functional

* fix glitch

* fix glitch #2
parent ff7c8168
...@@ -18,6 +18,7 @@ import math ...@@ -18,6 +18,7 @@ import math
from typing import Callable, Iterable, Optional, Tuple, Union from typing import Callable, Iterable, Optional, Tuple, Union
import torch import torch
from torch import nn
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
...@@ -272,7 +273,7 @@ class AdamW(Optimizer): ...@@ -272,7 +273,7 @@ class AdamW(Optimizer):
<https://arxiv.org/abs/1711.05101>`__. <https://arxiv.org/abs/1711.05101>`__.
Parameters: Parameters:
params (:obj:`Iterable[torch.nn.parameter.Parameter]`): params (:obj:`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups. Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (:obj:`float`, `optional`, defaults to 1e-3): lr (:obj:`float`, `optional`, defaults to 1e-3):
The learning rate to use. The learning rate to use.
...@@ -288,7 +289,7 @@ class AdamW(Optimizer): ...@@ -288,7 +289,7 @@ class AdamW(Optimizer):
def __init__( def __init__(
self, self,
params: Iterable[torch.nn.parameter.Parameter], params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3, lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999), betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6, eps: float = 1e-6,
...@@ -379,7 +380,7 @@ class Adafactor(Optimizer): ...@@ -379,7 +380,7 @@ class Adafactor(Optimizer):
`relative_step=False`. `relative_step=False`.
Arguments: Arguments:
params (:obj:`Iterable[torch.nn.parameter.Parameter]`): params (:obj:`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups. Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (:obj:`float`, `optional`): lr (:obj:`float`, `optional`):
The external learning rate. The external learning rate.
......
...@@ -264,7 +264,7 @@ class Trainer: ...@@ -264,7 +264,7 @@ class Trainer:
def __init__( def __init__(
self, self,
model: Union[PreTrainedModel, torch.nn.Module] = None, model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None, args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None, data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None, train_dataset: Optional[Dataset] = None,
...@@ -772,7 +772,7 @@ class Trainer: ...@@ -772,7 +772,7 @@ class Trainer:
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
""" """
if self.optimizer is None: if self.optimizer is None:
decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm]) decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name] decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [ optimizer_grouped_parameters = [
{ {
...@@ -933,7 +933,7 @@ class Trainer: ...@@ -933,7 +933,7 @@ class Trainer:
# Multi-gpu training (should be after apex fp16 initialization) # Multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1: if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Note: in torch.distributed mode, there's no point in wrapping the model # Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways. # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
...@@ -970,7 +970,7 @@ class Trainer: ...@@ -970,7 +970,7 @@ class Trainer:
find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False) find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
else: else:
find_unused_parameters = True find_unused_parameters = True
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, model,
device_ids=[self.args.local_rank], device_ids=[self.args.local_rank],
output_device=self.args.local_rank, output_device=self.args.local_rank,
...@@ -1288,7 +1288,7 @@ class Trainer: ...@@ -1288,7 +1288,7 @@ class Trainer:
model.clip_grad_norm_(args.max_grad_norm) model.clip_grad_norm_(args.max_grad_norm)
else: else:
# Revert to normal clipping otherwise, handling Apex or full precision # Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_( nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer) if self.use_apex else model.parameters(), amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
args.max_grad_norm, args.max_grad_norm,
) )
......
...@@ -28,6 +28,7 @@ from typing import Dict, Iterator, List, Optional, Union ...@@ -28,6 +28,7 @@ from typing import Dict, Iterator, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from packaging import version from packaging import version
from torch import nn
from torch.utils.data.dataset import Dataset, IterableDataset from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler from torch.utils.data.sampler import RandomSampler, Sampler
...@@ -441,7 +442,7 @@ class LabelSmoother: ...@@ -441,7 +442,7 @@ class LabelSmoother:
def __call__(self, model_output, labels): def __call__(self, model_output, labels):
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0] logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
log_probs = -torch.nn.functional.log_softmax(logits, dim=-1) log_probs = -nn.functional.log_softmax(logits, dim=-1)
if labels.dim() == log_probs.dim() - 1: if labels.dim() == log_probs.dim() - 1:
labels = labels.unsqueeze(-1) labels = labels.unsqueeze(-1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment