"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "651408a077f842e76e75bfc7d02b8ac38eeb6480"
Unverified commit a2789add, authored by cyy and committed by GitHub
Browse files

[Reland] search model buffers for dtype as the last resort (#23319)

search model buffers for dtype as the last resort
parent 3d764fe8
...@@ -207,22 +207,29 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -207,22 +207,29 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# if no floating dtype was found return whatever the first dtype is # if no floating dtype was found return whatever the first dtype is
return last_dtype return last_dtype
else: # For nn.DataParallel compatibility in PyTorch > 1.5
# For nn.DataParallel compatibility in PyTorch > 1.5 def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] return tuples
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
gen = parameter._named_members(get_members_fn=find_tensor_attributes) last_tuple = None
last_tuple = None for tuple in gen:
for tuple in gen: last_tuple = tuple
last_tuple = tuple if tuple[1].is_floating_point():
if tuple[1].is_floating_point(): return tuple[1].dtype
return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype # fallback to the last dtype
return last_tuple[1].dtype return last_tuple[1].dtype
# fallback to buffer dtype
for t in parameter.buffers():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
return last_dtype
def get_state_dict_float_dtype(state_dict): def get_state_dict_float_dtype(state_dict):
""" """
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment