Unverified Commit a2789add authored by cyy's avatar cyy Committed by GitHub
Browse files

[Reland] search model buffers for dtype as the last resort (#23319)

search model buffers for dtype as the last resort
parent 3d764fe8
......@@ -207,7 +207,6 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# if no floating dtype was found return whatever the first dtype is
return last_dtype
else:
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
......@@ -220,9 +219,17 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
if tuple[1].is_floating_point():
return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype
# fallback to buffer dtype
for t in parameter.buffers():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
return last_dtype
def get_state_dict_float_dtype(state_dict):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment