"examples/trials/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "bc0e55a00bbbc825f27d851ccc58a749d18b4fd9"
Unverified Commit ef42c2c4 authored by cyy's avatar cyy Committed by GitHub
Browse files

search buffers for dtype (#23159)

parent 312b104f
...@@ -207,21 +207,29 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -207,21 +207,29 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# if no floating dtype was found return whatever the first dtype is # if no floating dtype was found return whatever the first dtype is
return last_dtype return last_dtype
else: for t in parameter.buffers():
# For nn.DataParallel compatibility in PyTorch > 1.5 last_dtype = t.dtype
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: if t.is_floating_point():
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] return t.dtype
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes) if last_dtype is not None:
last_tuple = None # if no floating dtype was found return whatever the first dtype is
for tuple in gen: return last_dtype
last_tuple = tuple
if tuple[1].is_floating_point(): # For nn.DataParallel compatibility in PyTorch > 1.5
return tuple[1].dtype def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
# fallback to the last dtype return tuples
return last_tuple[1].dtype
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
# fallback to the last dtype
return last_tuple[1].dtype
def get_state_dict_float_dtype(state_dict): def get_state_dict_float_dtype(state_dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment