"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "c3407300eb01be249c428148442c8b71b33da659"
Unverified commit ccbc8cf4, authored by Tim Moon, committed by GitHub

[PyTorch] Register weight and bias params in linear op (#2027)



* Register weight/bias params in linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak docs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure linear op checkpoint is backward-compatible
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Check for invalid case before setting bias
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c582f6be
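
The user-visible effect of this change is that `weight` and `bias` become registered parameters of the fused `Linear` op itself, instead of Python properties forwarding to `basic_ops[0]`/`basic_ops[1]`. A minimal sketch of what that enables, assuming the op is importable as `transformer_engine.pytorch.ops.Linear` and that the constructor takes the `in_features`/`out_features`/`bias` arguments shown in the diff below:

```python
import transformer_engine.pytorch.ops as te_ops

linear = te_ops.Linear(in_features=16, out_features=32, bias=True)

# weight and bias are now real registered parameters of the fused op,
# so standard PyTorch introspection sees them directly. The same
# Parameter objects are shared with the owning basic ops, and
# named_parameters() deduplicates shared Parameters by default.
for name, param in linear.named_parameters():
    print(name, tuple(param.shape))

# Assigning via register_parameter() keeps the owning basic op in
# sync (see the overridden register_parameter() in the diff below).
```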
@@ -6,7 +6,7 @@
 from __future__ import annotations
 
 from collections.abc import Callable
-from typing import Optional
+from typing import Any, Optional
 
 import torch
@@ -91,6 +91,8 @@ class Linear(FusedOperation):
 
         # Construct basic ops
         ops = []
+        linear_idx = None
+        bias_idx = None
         linear_kwargs = {
             "in_features": in_features,
             "out_features": out_features,
@@ -111,14 +113,16 @@ class Linear(FusedOperation):
         }
         if tensor_parallel_mode == "row":
             # Row TP: GEMM + bias + reduction
+            linear_idx = len(ops)
             linear_kwargs["in_features"] = local_in_features
             linear_kwargs["out_features"] = local_out_features
             linear_kwargs["tensor_parallel_mode"] = None
             linear_kwargs["tensor_parallel_group"] = None
             linear_kwargs["sequence_parallel"] = False
-            bias_kwargs["size"] *= tensor_parallel_size
             ops.append(BasicLinear(**linear_kwargs))
             if bias:
+                bias_idx = len(ops)
+                bias_kwargs["size"] *= tensor_parallel_size
                 ops.append(Bias(**bias_kwargs))
             if sequence_parallel:
                 ops.append(ReduceScatter(tensor_parallel_group))
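
Because the number and order of basic ops depends on the parallel configuration, the constructor now records `linear_idx`/`bias_idx` at build time rather than hard-coding `basic_ops[0]`/`basic_ops[1]`. A rough stand-alone mirror of the layout logic above (illustration only; strings stand in for the real op modules):

```python
from typing import Optional

def op_layout(
    tensor_parallel_mode: Optional[str],
    bias: bool,
    sequence_parallel: bool,
) -> tuple[list[str], int, Optional[int]]:
    """Mirror the constructor's op-list construction (illustration only)."""
    ops: list[str] = []
    linear_idx = len(ops)  # position of BasicLinear
    ops.append("BasicLinear")
    bias_idx: Optional[int] = None
    if bias:
        bias_idx = len(ops)  # position of Bias
        ops.append("Bias")
    if tensor_parallel_mode == "row":
        # Row TP appends a reduction after GEMM + bias
        ops.append("ReduceScatter" if sequence_parallel else "AllReduce")
    return ops, linear_idx, bias_idx

print(op_layout("row", bias=True, sequence_parallel=False))
# (['BasicLinear', 'Bias', 'AllReduce'], 0, 1)
print(op_layout(None, bias=False, sequence_parallel=False))
# (['BasicLinear'], 0, None)
```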
@@ -126,45 +130,81 @@ class Linear(FusedOperation):
                 ops.append(AllReduce(tensor_parallel_group))
         else:
             # Column TP or no TP: (gather + GEMM) + bias
+            linear_idx = len(ops)
             ops.append(BasicLinear(**linear_kwargs))
             if bias:
+                bias_idx = len(ops)
                 ops.append(Bias(**bias_kwargs))
 
         # Initialize base class
         super().__init__(ops)
-        self._has_bias: bool = bias
 
-    @property
-    def weight(self) -> torch.nn.Parameter:
-        """Weight tensor
+        # Register parameters
+        self._linear_idx: Optional[int] = linear_idx
+        self._bias_idx: Optional[int] = bias_idx
+        self.register_parameter("weight", self.basic_ops[self._linear_idx].weight)
+        bias = None
+        if self._bias_idx is not None:
+            bias = self.basic_ops[self._bias_idx].bias
+        self.register_parameter("bias", bias)
 
-        Parameter is owned by `BasicLinear` operation.
+    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+        """Add a parameter to the module
 
-        """
-        return self.basic_ops[0].weight
-
-    @weight.setter
-    def weight(self, value: Optional[torch.nn.Parameter]) -> None:
-        self.basic_ops[0].weight = value
-
-    @property
-    def bias(self) -> Optional[torch.nn.Parameter]:
-        """Bias tensor
-
-        Parameter is owned by `Bias` operation.
+        Also updates the basic operation that owns the parameter.
 
         """
-        if self._has_bias:
-            return self.basic_ops[1].bias
-        return None
-
-    @bias.setter
-    def bias(self, value: Optional[torch.nn.Parameter]) -> None:
-        if self._has_bias:
-            self.basic_ops[1].bias = value
-        elif value is not None:
+        if name == "bias" and self._bias_idx is None and param is not None:
             raise ValueError(
                 "Attempted to set bias parameter in Linear operation "
                 "that does not have bias enabled"
             )
+        super().register_parameter(name, param)
+        if name == "weight":
+            self.basic_ops[self._linear_idx].weight = param
+        elif name == "bias" and self._bias_idx is not None:
+            self.basic_ops[self._bias_idx].bias = param
+
+    def state_dict(self, *, prefix: str = "", **kwargs) -> dict[str, Any]:
+        """Save state"""
+        state_dict = super().state_dict(prefix=prefix, **kwargs)
+
+        # Remove fused op params from state dict
+        # Note: Logically, basic ops own params and fused ops are
+        # considered stateless. However, we register the weight and
+        # bias params in the linear op for convenience. We remove
+        # these redundant entries from the checkpoint for backward
+        # compatibility.
+        if f"{prefix}weight" in state_dict:
+            del state_dict[f"{prefix}weight"]
+        if f"{prefix}bias" in state_dict:
+            del state_dict[f"{prefix}bias"]
+
+        return state_dict
+
+    def _load_from_state_dict(
+        self,
+        state_dict: dict[str, Any],
+        prefix: str,
+        *args,
+        **kwargs,
+    ) -> None:
+
+        # Restore fused op params to state dict
+        # Note: Logically, basic ops own params and fused ops are
+        # considered stateless. However, we register the weight and
+        # bias params in the linear op for convenience. Checkpoints
+        # omit these redundant entries, so we reconstruct them here
+        # for backward compatibility.
+        if f"{prefix}weight" not in state_dict:
+            state_dict[f"{prefix}weight"] = state_dict[
+                f"{prefix}basic_ops.{self._linear_idx}.weight"
+            ]
+        if f"{prefix}bias" not in state_dict:
+            if self._bias_idx is None:
+                state_dict[f"{prefix}bias"] = None
+            else:
+                state_dict[f"{prefix}bias"] = state_dict[
+                    f"{prefix}basic_ops.{self._bias_idx}.bias"
+                ]
+
+        # Load state dict
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
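
Taken together, the two overrides keep the on-disk checkpoint identical to the pre-change format (only `basic_ops.N.*` keys), while loading re-derives the fused-op keys so neither old nor new checkpoints trip missing-key errors. A sketch of the round trip, under the same import and constructor assumptions as the example above:

```python
import transformer_engine.pytorch.ops as te_ops

linear = te_ops.Linear(in_features=16, out_features=32, bias=True)

# state_dict() drops the redundant fused-op keys, so the checkpoint
# keeps the pre-change format with only basic-op entries.
sd = linear.state_dict()
assert "weight" not in sd and "bias" not in sd
assert any(key.startswith("basic_ops.") for key in sd)

# _load_from_state_dict() reconstructs the fused-op keys from the
# basic-op entries before delegating to the base implementation.
fresh = te_ops.Linear(in_features=16, out_features=32, bias=True)
fresh.load_state_dict(sd)
```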