Unverified Commit 3c05b9f7 authored by Kimbing Ng's avatar Kimbing Ng Committed by GitHub
Browse files

Fixes #12673. `record_stream` in group offloading is not working properly (#12721)



* Fixes #12673.

    Wrong default_stream is used, leading to wrong execution order when record_stream is enabled.

* update

* Update test

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 9379b239
...@@ -153,27 +153,27 @@ class ModuleGroup: ...@@ -153,27 +153,27 @@ class ModuleGroup:
finally: finally:
pinned_dict = None pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor): def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream: if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream()) tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None): def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules: for group_module in self.modules:
for param in group_module.parameters(): for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source) self._transfer_tensor_to_device(param, source, default_stream)
for buffer in group_module.buffers(): for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source) self._transfer_tensor_to_device(buffer, source, default_stream)
for param in self.parameters: for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source) self._transfer_tensor_to_device(param, source, default_stream)
for buffer in self.buffers: for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source) self._transfer_tensor_to_device(buffer, source, default_stream)
def _onload_from_disk(self): def _onload_from_disk(self):
if self.stream is not None: if self.stream is not None:
...@@ -208,10 +208,12 @@ class ModuleGroup: ...@@ -208,10 +208,12 @@ class ModuleGroup:
self.stream.synchronize() self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream) context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
with context: with context:
if self.stream is not None: if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory: with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory) self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
else: else:
self._process_tensors_from_modules(None) self._process_tensors_from_modules(None)
......
...@@ -1814,9 +1814,6 @@ class ModelTesterMixin: ...@@ -1814,9 +1814,6 @@ class ModelTesterMixin:
torch.manual_seed(0) torch.manual_seed(0)
return model(**inputs_dict)[0] return model(**inputs_dict)[0]
if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0) torch.manual_seed(0)
model = self.model_class(**init_dict) model = self.model_class(**init_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment