Unverified commit a2789add, authored by cyy and committed by GitHub
Browse files

[Reland] search model buffers for dtype as the last resort (#23319)

search model buffers for dtype as the last resort
parent 3d764fe8
...@@ -207,7 +207,6 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -207,7 +207,6 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# if no floating dtype was found return whatever the first dtype is # if no floating dtype was found return whatever the first dtype is
return last_dtype return last_dtype
else:
# For nn.DataParallel compatibility in PyTorch > 1.5 # For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
...@@ -220,9 +219,17 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -220,9 +219,17 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
if tuple[1].is_floating_point(): if tuple[1].is_floating_point():
return tuple[1].dtype return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype # fallback to the last dtype
return last_tuple[1].dtype return last_tuple[1].dtype
# fallback to buffer dtype
for t in parameter.buffers():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
return last_dtype
def get_state_dict_float_dtype(state_dict): def get_state_dict_float_dtype(state_dict):
""" """
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment