Unverified Commit 4b919657 authored by Lysandre Debut, committed by GitHub

Factor out methods (#10215)

parent e94d63f6
@@ -86,6 +86,36 @@ def find_pruneable_heads_and_indices(
     return heads, index


+def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).device
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).dtype
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].dtype
+
+
 class ModuleUtilsMixin:
     """
     A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
@@ -145,36 +175,14 @@ class ModuleUtilsMixin:
         :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
         device).
         """
-        try:
-            return next(self.parameters()).device
-        except StopIteration:
-            # For nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = self._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].device
+        return get_parameter_device(self)

     @property
     def dtype(self) -> dtype:
         """
         :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         """
-        try:
-            return next(self.parameters()).dtype
-        except StopIteration:
-            # For nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = self._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].dtype
+        return get_parameter_dtype(self)

     def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         """
@@ -1238,7 +1246,7 @@ class PoolerStartLogits(nn.Module):
         x = self.dense(hidden_states).squeeze(-1)

         if p_mask is not None:
-            if next(self.parameters()).dtype == torch.float16:
+            if get_parameter_dtype(self) == torch.float16:
                 x = x * (1 - p_mask) - 65500 * p_mask
             else:
                 x = x * (1 - p_mask) - 1e30 * p_mask
@@ -1305,7 +1313,7 @@ class PoolerEndLogits(nn.Module):
         x = self.dense_1(x).squeeze(-1)

         if p_mask is not None:
-            if next(self.parameters()).dtype == torch.float16:
+            if get_parameter_dtype(self) == torch.float16:
                 x = x * (1 - p_mask) - 65500 * p_mask
             else:
                 x = x * (1 - p_mask) - 1e30 * p_mask
...
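
What the commit does: the device/dtype lookup that was duplicated inside ModuleUtilsMixin.device and ModuleUtilsMixin.dtype is factored out into the module-level helpers get_parameter_device and get_parameter_dtype, so the same logic can be reused outside the mixin (as in the PoolerStartLogits/PoolerEndLogits hunks above). Below is a simplified standalone sketch of the dtype helper's two code paths; it is not the transformers source (the real helper walks attributes via _named_members), and NoParams is a hypothetical module invented here just to trigger the fallback:

import torch
from torch import nn


# Simplified re-statement of the factored-out helper, for illustration only.
def get_parameter_dtype(module: nn.Module) -> torch.dtype:
    try:
        # Normal path: dtype of the first registered parameter.
        return next(module.parameters()).dtype
    except StopIteration:
        # Fallback path: nn.DataParallel replicas in PyTorch 1.5 can expose
        # no parameters(); look for tensors stored directly on the instance.
        tensors = (v for v in module.__dict__.values() if torch.is_tensor(v))
        return next(tensors).dtype


class NoParams(nn.Module):
    """Hypothetical module whose only tensor is a plain attribute."""

    def __init__(self):
        super().__init__()
        self.scale = torch.ones(1, dtype=torch.float16)  # not an nn.Parameter


print(get_parameter_dtype(nn.Linear(4, 2)))  # torch.float32 (normal path)
print(get_parameter_dtype(NoParams()))       # torch.float16 (fallback path)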
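
The dtype branch in the pooler hunks exists because float16 cannot represent 1e30: its largest finite value is about 65504, so 1e30 overflows to inf, and IEEE arithmetic then gives inf * 0 = nan at the unmasked positions of the formula. The fp16 path therefore masks with the finite constant -65500. A quick standalone check of that constant choice:

import torch

# float16 overflows past ~65504, so 1e30 is not representable:
print(torch.tensor(1e30, dtype=torch.float16))     # tensor(inf, dtype=torch.float16)
print(torch.tensor(65500.0, dtype=torch.float16))  # finite (rounds to 65504)

# With an infinite constant the formula breaks even at UNmasked positions,
# because inf * 0 = nan:
x = torch.tensor([1.0, 2.0], dtype=torch.float16)
p_mask = torch.tensor([0.0, 1.0], dtype=torch.float16)
inf_const = torch.tensor(1e30, dtype=torch.float16)
print(x * (1 - p_mask) - inf_const * p_mask)  # tensor([nan, -inf], ...)

# The -65500 constant stays finite, so masked logits just become very
# negative and softmax drives their probability to ~0:
masked = x * (1 - p_mask) - 65500 * p_mask
print(masked)                  # tensor([1., -65504.], dtype=torch.float16)
print(masked.softmax(dim=-1))  # ~[1., 0.]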