Unverified commit 5d5031e9, authored by ver217 and committed by GitHub
Browse files

fix zero ddp state dict (#1378)

parent 0c1a16ea
...@@ -314,14 +314,18 @@ class ZeroDDP(ColoDDP): ...@@ -314,14 +314,18 @@ class ZeroDDP(ColoDDP):
module module
""" """
chunks = self.chunk_manager.get_chunks(self.fp32_params) chunks = self.chunk_manager.get_chunks(self.fp32_params)
chunks_orig_device_type = []
for chunk in chunks: for chunk in chunks:
chunks_orig_device_type.append(chunk.device_type)
self.chunk_manager.access_chunk(chunk) self.chunk_manager.access_chunk(chunk)
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
if p is not None: if p is not None:
rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu() rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
destination[prefix + name] = rec_p if keep_vars else rec_p.detach() destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
for chunk in chunks: for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks):
self.chunk_manager.release_chunk(chunk) self.chunk_manager.release_chunk(chunk)
if not chunk.is_empty and orig_dvice_type == 'cpu':
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
for name, buf in self.named_buffers(): for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set: if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach() destination[prefix + name] = buf if keep_vars else buf.detach()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment