"examples/seq2seq/bertabs/__init__.py" did not exist on "81d6841b4be25a164235975e5ebdcf99d7a26633"
Unverified Commit ef42c2c4 authored by cyy's avatar cyy Committed by GitHub
Browse files

search buffers for dtype (#23159)

parent 312b104f
......@@ -207,7 +207,15 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# if no floating dtype was found return whatever the first dtype is
return last_dtype
else:
for t in parameter.buffers():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment