Unverified Commit 58d2ebab authored by Kirthi Shankar Sivamani, committed by GitHub

Deprecate unused APIs (#321)



* Deprecate unused APIs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b172bad8
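The change below removes the code path in which a module is constructed without its own weights and the tensors are supplied at call time. A minimal migration sketch follows (hypothetical user code, assuming a CUDA-enabled build of Transformer Engine; the sizes are illustrative):

```python
import torch
import transformer_engine.pytorch as te

# Previously possible pattern: construct without weights and pass them at
# forward time. After this change the constructor argument only emits a
# DeprecationWarning and is ignored, and forward() rejects `weight`/`bias`:
#
#   layer = te.Linear(1024, 1024, skip_weight_param_allocation=True)
#   out = layer(inp, weight=my_weight, bias=my_bias)   # now raises RuntimeError

# Supported pattern: let the module allocate and own its parameters.
layer = te.Linear(1024, 1024, bias=True)
inp = torch.randn(16, 1024, device="cuda")
out = layer(inp)
```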
@@ -4,6 +4,7 @@
 """LayerNormLinear API"""
 import os
+import warnings
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
@@ -538,6 +539,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
     r"""
     Applies layer normalization followed by linear transformation to the incoming data.
 
+    .. warning::
+
+        Argument :attr:`skip_weight_param_allocation` is deprecated and will
+        be fully removed in future releases.
+
     Parameters
     ----------
     in_features : int
@@ -585,9 +591,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
                       used to decide whether this Linear layer is Column Parallel Linear or Row
                       Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                       When set to `None`, no communication is performed.
-    skip_weight_param_allocation: bool, default = `False`
-                                 if set to `True`, weight parameter is not allocated and must be
-                                 passed as a keyword argument `weight` during the forward pass.
 
     Optimization parameters
     -----------------------
@@ -633,6 +636,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
     ) -> None:
         super().__init__()
 
+        if skip_weight_param_allocation:
+            warnings.warn(
+                "Argument `skip_weight_param_allocation` is deprecated and "
+                "will be fully removed in future releases. It is ignored "
+                "starting from v0.11.",
+                category=DeprecationWarning,
+            )
+
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features
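As context for the warning added above: `DeprecationWarning` is filtered out by default in many Python configurations, so callers of the deprecated argument may not see anything unless warnings are enabled. A self-contained sketch using a toy stand-in for the constructor (plain standard library, not Transformer Engine code) shows how to surface it, or escalate it to an error during testing:

```python
import warnings

def make_layer(skip_weight_param_allocation: bool = False):
    # Toy stand-in that emits the same warning as the constructor above.
    if skip_weight_param_allocation:
        warnings.warn(
            "Argument `skip_weight_param_allocation` is deprecated and "
            "will be fully removed in future releases. It is ignored "
            "starting from v0.11.",
            category=DeprecationWarning,
        )

# Surface the warning explicitly, or promote it to an error to flush out
# remaining callers in CI.
warnings.simplefilter("always", DeprecationWarning)
make_layer(skip_weight_param_allocation=True)    # prints a DeprecationWarning

with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    try:
        make_layer(skip_weight_param_allocation=True)
    except DeprecationWarning:
        print("deprecated argument detected")
```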
@@ -695,72 +706,71 @@ class LayerNormLinear(TransformerEngineBaseModule):
         setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
         self.reset_layer_norm_parameters()
 
-        if not skip_weight_param_allocation:
         self.weight_tensor = torch.empty(
             self.out_features, self.in_features,
             device=torch.cuda.current_device(),
             dtype=params_dtype)
 
         initialize_affine_weight_gpu(
             self.weight_tensor,
             init_method,
             get_rng_state_tracker,
             partition_dim=1 if self.parallel_mode == "row" else 0,
             stride=1,
         )
 
         if self.use_bias:
             self.bias_tensor = torch.empty(
                 self.out_features,
                 device=torch.cuda.current_device(),
                 dtype=params_dtype)
         else:
             self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
                                                  device=torch.cuda.current_device())
 
         with torch.no_grad():
             self.bias_tensor.zero_()
 
         if parameters_split is None:
             parameters_split = ("",)
 
         assert (
             self.out_features % len(parameters_split) == 0
         ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
 
         split_size = self.out_features // len(parameters_split)
 
         self.weight_names = []
         self.bias_names = []
 
         for i, pname in enumerate(parameters_split):
             wname = pname + "weight"
             bname = pname + "bias"
 
             self.register_parameter(
                 wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
             )
 
             set_tensor_model_parallel_attributes(
                 tensor=getattr(self, wname),
                 is_parallel=True,
                 dim=1 if parallel_mode == "row" else 0,
                 stride=1,
             )
 
             if self.use_bias:
                 self.register_parameter(
                     bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                 )
             else:
                 setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
                                                        device=torch.cuda.current_device()))
 
             if parallel_mode == "column":
                 set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
 
             self.weight_names.append(wname)
             self.bias_names.append(bname)
 
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -821,17 +831,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
         """
         Apply layer normalization to the input followed by a linear transformation.
 
+        .. warning::
+
+            Arguments :attr:`weight` and :attr:`bias` are deprecated and will
+            be fully removed in future releases.
+
         Parameters
         ----------
         inp : torch.Tensor
              Input tensor.
-        weight : torch.Tensor, default = None
-               An optional weight tensor for the module. This argument is compulsory if module
-               is initialized with `skip_weight_param_allocation=True`
-        bias : torch.Tensor, default = None
-             An optional bias tensor for the module. This argument is compulsory if module
-             is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
-             or `return_bias`
         is_first_microbatch : {True, False, None}, default = None
                              During training using either gradient accumulation or
                              pipeline parallelism a minibatch of data is further split
@@ -847,16 +855,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
                                produced)
         """
 
+        if weight is not None or bias is not None:
+            raise RuntimeError(
+                "Arguments `weight` and `bias` are deprecated and "
+                "will be fully removed in future releases."
+            )
+
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (
-                bias if bias is not None
-                else self.bias if self.parameters_split is None
+                self.bias if self.parameters_split is None
                 else self.bias_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("bias_tensor", self.bias_names)
             )
             weight_tensor = (
-                weight if weight is not None
-                else self.weight if self.parameters_split is None
+                self.weight if self.parameters_split is None
                 else self.weight_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("weight_tensor", self.weight_names)
             )
......
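To illustrate the forward-pass change above, here is a short, hypothetical usage sketch (it assumes a CUDA device and this release of Transformer Engine, where the `weight`/`bias` keywords still exist in the forward signature but are rejected):

```python
import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(1024, 3072)
x = torch.randn(8, 1024, device="cuda")

y = layer(x)  # supported: the module-owned weight and bias are used

try:
    layer(x, weight=torch.empty(3072, 1024, device="cuda"))
except RuntimeError as err:
    print(err)  # "Arguments `weight` and `bias` are deprecated and will be fully removed ..."
```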
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 
 """Linear API"""
+import warnings
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 import torch
@@ -441,6 +442,11 @@ class Linear(TransformerEngineBaseModule):
 
     On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
 
+    .. warning::
+
+        Argument :attr:`skip_weight_param_allocation` is deprecated and will
+        be fully removed in future releases.
+
     Parameters
     ----------
     in_features : int
@@ -474,9 +480,6 @@ class Linear(TransformerEngineBaseModule):
                       used to decide whether this Linear layer is Column Parallel Linear or Row
                       Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                       When set to `None`, no communication is performed.
-    skip_weight_param_allocation: bool, default = `False`
-                                 if set to `True`, weight parameter is not allocated and must be
-                                 passed as a keyword argument `weight` during the forward pass.
 
     Optimization parameters
     -----------------------
@@ -518,6 +521,14 @@ class Linear(TransformerEngineBaseModule):
     ) -> None:
         super().__init__()
 
+        if skip_weight_param_allocation:
+            warnings.warn(
+                "Argument `skip_weight_param_allocation` is deprecated and "
+                "will be fully removed in future releases. It is ignored "
+                "starting from v0.11.",
+                category=DeprecationWarning,
+            )
+
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features
@@ -558,72 +569,71 @@ class Linear(TransformerEngineBaseModule):
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
 
-        if not skip_weight_param_allocation:
         self.weight_tensor = torch.empty(
             self.out_features, self.in_features,
             device=torch.cuda.current_device(),
             dtype=params_dtype)
 
         initialize_affine_weight_gpu(
             self.weight_tensor,
             init_method,
             get_rng_state_tracker,
             partition_dim=1 if self.parallel_mode == "row" else 0,
             stride=1,
         )
 
         if self.use_bias:
             self.bias_tensor = torch.empty(
                 self.out_features,
                 device=torch.cuda.current_device(),
                 dtype=params_dtype)
         else:
             self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
                                                  device=torch.cuda.current_device())
 
         with torch.no_grad():
             self.bias_tensor.zero_()
 
         if parameters_split is None:
             parameters_split = ("",)
 
         assert (
             self.out_features % len(parameters_split) == 0
         ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
 
         split_size = self.out_features // len(parameters_split)
 
         self.weight_names = []
         self.bias_names = []
 
         for i, pname in enumerate(parameters_split):
             wname = pname + "weight"
             bname = pname + "bias"
 
             self.register_parameter(
                 wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
             )
 
             set_tensor_model_parallel_attributes(
                 tensor=getattr(self, wname),
                 is_parallel=True,
                 dim=1 if parallel_mode == "row" else 0,
                 stride=1,
             )
 
             if self.use_bias:
                 self.register_parameter(
                     bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                 )
             else:
                 setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
                                                        device=torch.cuda.current_device()))
 
             if parallel_mode == "column":
                 set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
 
             self.weight_names.append(wname)
             self.bias_names.append(bname)
 
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
@@ -668,17 +678,15 @@ class Linear(TransformerEngineBaseModule):
         """
         Apply the linear transformation to the input.
 
+        .. warning::
+
+            Arguments :attr:`weight` and :attr:`bias` are deprecated and will
+            be fully removed in future releases.
+
         Parameters
         ----------
         inp : torch.Tensor
              Input tensor.
-        weight : torch.Tensor, default = None
-               An optional weight tensor for the module. This argument is compulsory if module
-               is initialized with `skip_weight_param_allocation=True`
-        bias : torch.Tensor, default = None
-             An optional bias tensor for the module. This argument is compulsory if module
-             is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
-             or `return_bias`
         is_first_microbatch : {True, False, None}, default = None
                              During training using either gradient accumulation or
                              pipeline parallelism a minibatch of data is further split
@@ -694,16 +702,20 @@ class Linear(TransformerEngineBaseModule):
                                produced)
         """
 
+        if weight is not None or bias is not None:
+            raise RuntimeError(
+                "Arguments `weight` and `bias` are deprecated and "
+                "will be fully removed in future releases."
+            )
+
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (
-                bias if bias is not None
-                else self.bias if self.parameters_split is None
+                self.bias if self.parameters_split is None
                 else self.bias_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("bias_tensor", self.bias_names)
             )
             weight_tensor = (
-                weight if weight is not None
-                else self.weight if self.parameters_split is None
+                self.weight if self.parameters_split is None
                 else self.weight_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("weight_tensor", self.weight_names)
            )
......
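Separately from the deprecation, the block kept (and re-indented) in both constructors implements the `parameters_split` pattern: one contiguous weight tensor exposed as several named `Parameter` views, so fused layers can be addressed individually. A standalone toy illustration of that idea (plain PyTorch, not Transformer Engine code; names are illustrative):

```python
import torch
from torch.nn import Module, Parameter

class SplitWeight(Module):
    """Toy module: one flat weight tensor exposed as per-split Parameter views."""

    def __init__(self, out_features, in_features, parameters_split=("",)):
        super().__init__()
        assert out_features % len(parameters_split) == 0
        weight_tensor = torch.empty(out_features, in_features)
        split_size = out_features // len(parameters_split)
        self.weight_names = []
        for i, pname in enumerate(parameters_split):
            wname = pname + "weight"
            # Each Parameter is a view into the same underlying storage.
            self.register_parameter(
                wname, Parameter(weight_tensor[i * split_size : (i + 1) * split_size])
            )
            self.weight_names.append(wname)

m = SplitWeight(6, 4, parameters_split=("query_", "key_", "value_"))
print([name for name, _ in m.named_parameters()])
# ['query_weight', 'key_weight', 'value_weight']
```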