Unverified Commit 34e5e11f authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny make device_loading_context more static (#9478)

parent 2600fc0d
...@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) ...@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
yield module yield module
return return
original_device_states: Dict[str, torch.device] = {} original_infos: Dict[str, Dict] = {}
# Store original device states and move parameters to GPU if they're on CPU # Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters(): for name, p in module.named_parameters():
if p.device.type == "cpu": if p.device.type == "cpu":
original_device_states[name] = p.device original_data = p.data
p.data = p.data.to(target_device) device_data = p.data.to(target_device)
original_infos[name] = dict(
device=p.device,
original_data=original_data,
device_data=device_data,
)
p.data = device_data
# Parameters already on target device are not touched # Parameters already on target device are not touched
try: try:
...@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device) ...@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
# Restore parameters to their original devices, ignoring new parameters # Restore parameters to their original devices, ignoring new parameters
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
for name, p in module.named_parameters(): for name, p in module.named_parameters():
if name in original_device_states: if name in original_infos:
original_device: torch.device = original_device_states[name] original_info = original_infos[name]
if original_device.type == "cpu": device_data = original_info["device_data"]
original_data = original_info["original_data"]
original_device: torch.device = original_info["device"]
if (
(device_data.device == p.data.device)
and (device_data.data_ptr() == p.data.data_ptr())
and (device_data.shape == p.data.shape)
and (device_data.dtype == p.data.dtype)
):
original_data.copy_(p.data.to(original_data.device))
p.data = original_data
elif original_device.type == "cpu":
# `torch.empty_like` does not support `pin_memory` argument # `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided( cpu_data = torch.empty_strided(
size=p.data.size(), size=p.data.size(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment