Unverified commit 35bbe740, authored by Tim Moon, committed by GitHub

[PyTorch] Remove fast param getter from modules (#1291)



* Add fallback for fast param getter

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove fast param getter

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 7fb22c37
@@ -11,7 +11,7 @@ import socket
 import fcntl
 import struct
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager

 import torch
@@ -408,15 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self._fp8_workspaces: Dict[str, Float8Tensor] = {}
         self.activation_dtype: Optional[torch.dtype] = None

-        # Fast getter for parameters
-        # Note: torch.nn.Module does not store parameters like normal
-        # attrs, but rather in a dict. When attempting to access, the
-        # module will raise an AttributeError in __getattribute__ and
-        # call a custom __getattr__. This is unnecessary overhead if
-        # we know we are accessing a parameter.
-        self._fast_get_param: Callable[[str], torch.nn.Parameter]
-        self._fast_get_param = self.__dict__["_parameters"].get
-
    # Names of attributes that can be set quickly (see __setattr__
    # method)
    _fast_setattr_names: Set[str] = {
...
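For context, the removed getter relied on a detail of `torch.nn.Module`: parameters live in the module's `_parameters` dict rather than as ordinary instance attributes, so `self.weight` first raises an internal `AttributeError` in `__getattribute__` and then falls back to `torch.nn.Module.__getattr__`, which searches `_parameters`, `_buffers`, and `_modules`. A minimal standalone sketch of the two access paths (illustrative toy module, not code from this repository):

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(8, 8))

m = Toy()

# Normal path: "weight" is not a plain instance attribute, so
# __getattribute__ fails internally and torch.nn.Module.__getattr__
# searches _parameters (then _buffers and _modules).
w1 = m.weight

# The removed fast path: bind the _parameters dict's .get method
# directly, skipping the exception handling and fallback search.
fast_get = m.__dict__["_parameters"].get
w2 = fast_get("weight")

assert w1 is w2
```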
@@ -720,7 +720,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:
-            weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)]
+            weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
             bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
             if not self.fp8:
                 weight_tensors = [
...
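`getattr` is the safer replacement because `_parameters.get` only sees registered parameters: if a weight is later replaced by a plain (for example, dequantized) tensor, the dict lookup silently returns `None`, while normal attribute lookup still resolves. A hedged sketch of that failure mode (toy module, not the commit's actual code):

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight0 = torch.nn.Parameter(torch.empty(4, 4))

m = Toy()

# Replace the registered parameter with a plain tensor, as might
# happen when a weight is swapped for a post-processed copy.
del m.weight0                  # removes the entry from _parameters
m.weight0 = torch.empty(4, 4)  # stored as an ordinary attribute

# The old fast getter now silently returns None ...
assert m.__dict__["_parameters"].get("weight0") is None
# ... while getattr still resolves the attribute correctly.
assert isinstance(getattr(m, "weight0"), torch.Tensor)
```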
@@ -1159,7 +1159,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             # Get concatenated weight and bias tensors
-            unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
+            unfused_weights = [getattr(self, name) for name in self.weight_names]
             if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                 if self.fp8:
                     if len(unfused_weights) != 1:
@@ -1170,9 +1170,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 unfused_weights = [w.dequantize() for w in unfused_weights]
             weight_tensor = _noop_cat(unfused_weights)
             if self.use_bias:
-                bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
+                bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names])
             else:
-                bias_tensor = self._fast_get_param(self.bias_names[0])  # Unused
+                bias_tensor = getattr(self, self.bias_names[0])  # Unused
             # Initialize FP8 weights if needed
             weight_fp8 = None
@@ -1206,8 +1206,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             args = [None]
             args += (
                 inp,
-                self._fast_get_param("layer_norm_weight"),
-                self._fast_get_param("layer_norm_bias"),
+                self.layer_norm_weight,
+                self.layer_norm_bias,
                 weight_tensor,
                 weight_fp8,
                 bias_tensor,
...
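LayerNormLinear and Linear share a name-based pattern here: per-split parameters are registered under the names in `self.weight_names`/`self.bias_names`, gathered with `getattr`, and concatenated. A rough sketch of the idea (names are illustrative; `_noop_cat` is Transformer Engine's helper that skips the copy when the splits are already views into one contiguous buffer, so plain `torch.cat` stands in here):

```python
import torch

class SplitLinear(torch.nn.Module):
    """Toy module holding per-split weights under generated names."""

    def __init__(self, splits=("q", "k", "v"), dim=4):
        super().__init__()
        self.weight_names = [f"{s}_weight" for s in splits]
        for name in self.weight_names:
            self.register_parameter(name, torch.nn.Parameter(torch.empty(dim, dim)))

    def fused_weight(self) -> torch.Tensor:
        # Gather the per-split parameters by name, then concatenate
        # along the output dimension.
        return torch.cat([getattr(self, name) for name in self.weight_names])

m = SplitLinear()
print(m.fused_weight().shape)  # torch.Size([12, 4])
```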
@@ -1491,10 +1491,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
         with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
             # Get weight tensors
-            fc1_weight = self._fast_get_param("fc1_weight")
-            fc1_bias = self._fast_get_param("fc1_bias")
-            fc2_weight = self._fast_get_param("fc2_weight")
-            fc2_bias = self._fast_get_param("fc2_bias")
+            fc1_weight = self.fc1_weight
+            fc1_bias = self.fc1_bias
+            fc2_weight = self.fc2_weight
+            fc2_bias = self.fc2_bias
             if not self.fp8:
                 if isinstance(fc1_weight, Float8Tensor):
                     fc1_weight = fc1_weight.from_float8()
@@ -1555,8 +1555,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             args = [None]
             args += (
                 inp,
-                self._fast_get_param("layer_norm_weight"),
-                self._fast_get_param("layer_norm_bias"),
+                self.layer_norm_weight,
+                self.layer_norm_bias,
                 fc1_weight,
                 fc1_weight_fp8,
                 fc1_bias,
...
@@ -950,7 +950,7 @@ class Linear(TransformerEngineBaseModule):
         ) as inp:
             # Get concatenated weight and bias tensors
-            unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
+            unfused_weights = [getattr(self, name) for name in self.weight_names]
             if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                 if self.fp8:
                     if len(unfused_weights) != 1:
@@ -961,9 +961,9 @@ class Linear(TransformerEngineBaseModule):
                 unfused_weights = [w.dequantize() for w in unfused_weights]
             weight_tensor = _noop_cat(unfused_weights)
             if self.use_bias:
-                bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
+                bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names])
             else:
-                bias_tensor = self._fast_get_param(self.bias_names[0])  # Unused
+                bias_tensor = getattr(self, self.bias_names[0])  # Unused
             # Initialize FP8 weights if needed
             weight_fp8 = None
...