Unverified commit 35bbe740, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Remove fast param getter from modules (#1291)



* Add fallback for fast param getter
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove fast param getter
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 7fb22c37
......@@ -11,7 +11,7 @@ import socket
import fcntl
import struct
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
import torch
......@@ -408,15 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
self.activation_dtype: Optional[torch.dtype] = None
# Fast getter for parameters
# Note: torch.nn.Module does not store parameters like normal
# attrs, but rather in a dict. When attempting to access, the
# module will raise an AttributeError in __getattribute__ and
# call a custom __getattr__. This is unnecessary overhead if
# we know we are accessing a parameter.
self._fast_get_param: Callable[str, torch.nn.Parameter]
self._fast_get_param = self.__dict__["_parameters"].get
# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
......
......@@ -720,7 +720,7 @@ class GroupedLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:
weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)]
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
weight_tensors = [
......
......@@ -1159,7 +1159,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
......@@ -1170,9 +1170,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused
bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Initialize FP8 weights if needed
weight_fp8 = None
......@@ -1206,8 +1206,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
args = [None]
args += (
inp,
self._fast_get_param("layer_norm_weight"),
self._fast_get_param("layer_norm_bias"),
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
weight_fp8,
bias_tensor,
......
......@@ -1491,10 +1491,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
# Get weight tensors
fc1_weight = self._fast_get_param("fc1_weight")
fc1_bias = self._fast_get_param("fc1_bias")
fc2_weight = self._fast_get_param("fc2_weight")
fc2_bias = self._fast_get_param("fc2_bias")
fc1_weight = self.fc1_weight
fc1_bias = self.fc1_bias
fc2_weight = self.fc2_weight
fc2_bias = self.fc2_bias
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.from_float8()
......@@ -1555,8 +1555,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
args = [None]
args += (
inp,
self._fast_get_param("layer_norm_weight"),
self._fast_get_param("layer_norm_bias"),
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_weight_fp8,
fc1_bias,
......
......@@ -950,7 +950,7 @@ class Linear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
......@@ -961,9 +961,9 @@ class Linear(TransformerEngineBaseModule):
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names])
else:
bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused
bias_tensor = getattr(self, self.bias_names[0]) # Unused
# Initialize FP8 weights if needed
weight_fp8 = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.