Unverified Commit c54af4c7 authored by haikuoxin's avatar haikuoxin Committed by GitHub
Browse files

Add a condition for nested_detach (#31855)

fix bug: https://github.com/huggingface/transformers/issues/31852
parent 080e14b2
......@@ -192,7 +192,7 @@ def nested_detach(tensors):
return type(tensors)(nested_detach(t) for t in tensors)
elif isinstance(tensors, Mapping):
return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
return tensors.detach()
return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors
def nested_xla_mesh_reduce(tensors, name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment