Unverified Commit fef44233 authored by Sam Shleifer, committed by GitHub

FSDP: better traceback for dtype assertion (#912)

parent 6b2f992c
@@ -279,8 +279,14 @@ class FlattenParamsWrapper(nn.Module):
         shared_param_memo: Dict[nn.Parameter, Tuple[str, nn.Module, str]] = {}
         shared_param_infos = []
         params = []
+        fp32 = []
+        fp16 = []
         for module_name, m in self.named_modules():
             for n, p in m.named_parameters(recurse=False):
+                if p.dtype != torch.float16:
+                    fp32.append(module_name)
+                else:
+                    fp16.append(module_name)
                 if p is not None and (m, n) in p_set:
                     if p in shared_param_memo:
                         mname, shared_m, shared_n = shared_param_memo[p]
@@ -290,8 +296,10 @@ class FlattenParamsWrapper(nn.Module):
                         param_infos.append((module_name, m, n))
                         params.append(p)
         del shared_param_memo
 
-        assert len(set(p.dtype for p in params)) == 1, "expects all parameters to have same dtype"
+        fp16_msg, fp32_msg = ",".join(fp16), ",".join(fp32)
+        assert (
+            len(set(p.dtype for p in params)) == 1
+        ), f"expects all parameters to have same dtype: fp32: {fp32_msg} \n fp16: {fp16_msg} "
         assert len(set(p.requires_grad for p in params)) == 1, "expects all parameters to have same requires_grad"
         assert len(params) == len(set(params)), "params list should not have dups"
         return params, param_infos, shared_param_infos
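
For context, here is a minimal standalone sketch of the pattern this commit adds; the toy two-layer model below is hypothetical and not part of fairscale. Module names are bucketed by parameter dtype so that, when the dtype assertion fires, the message names the offending submodules instead of failing opaquely.

import torch
import torch.nn as nn

# Hypothetical model with mixed dtypes: submodule "0" is fp16, submodule "1" is fp32.
model = nn.Sequential(nn.Linear(4, 4).half(), nn.Linear(4, 4))

# Bucket module names by the dtype of their direct (non-recursive) parameters,
# mirroring the loop added in this commit.
fp32, fp16 = [], []
for module_name, m in model.named_modules():
    for n, p in m.named_parameters(recurse=False):
        if p.dtype != torch.float16:
            fp32.append(module_name)
        else:
            fp16.append(module_name)

params = list(model.parameters())
fp16_msg, fp32_msg = ",".join(fp16), ",".join(fp32)
assert (
    len(set(p.dtype for p in params)) == 1
), f"expects all parameters to have same dtype: fp32: {fp32_msg} \n fp16: {fp16_msg} "

Running this raises an AssertionError whose message lists "fp32: 1,1" and "fp16: 0,0" (one entry per weight and bias), so the module names point straight at the mismatched submodules. That is the improved traceback the commit title refers to.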