Unverified Commit 273f5ba0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Revert "search buffers for dtype" (#23308)

Revert "search buffers for dtype (#23159)"

This reverts commit ef42c2c4.
parent ba71d9e9
...@@ -207,15 +207,7 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -207,15 +207,7 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# if no floating dtype was found return whatever the first dtype is # if no floating dtype was found return whatever the first dtype is
return last_dtype return last_dtype
for t in parameter.buffers(): else:
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# For nn.DataParallel compatibility in PyTorch > 1.5 # For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment