"sgl-kernel/python/sgl_kernel/hadamard.py" did not exist on "868403f6425bc7f237d82fe6469ff5a58055c4c5"
Unverified Commit 34e5e11f authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny make device_loading_context more static (#9478)

parent 2600fc0d
......@@ -79,13 +79,19 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
yield module
return
original_device_states: Dict[str, torch.device] = {}
original_infos: Dict[str, Dict] = {}
# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
if p.device.type == "cpu":
original_device_states[name] = p.device
p.data = p.data.to(target_device)
original_data = p.data
device_data = p.data.to(target_device)
original_infos[name] = dict(
device=p.device,
original_data=original_data,
device_data=device_data,
)
p.data = device_data
# Parameters already on target device are not touched
try:
......@@ -95,9 +101,21 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
# Restore parameters to their original devices, ignoring new parameters
pin_memory = is_pin_memory_available()
for name, p in module.named_parameters():
if name in original_device_states:
original_device: torch.device = original_device_states[name]
if original_device.type == "cpu":
if name in original_infos:
original_info = original_infos[name]
device_data = original_info["device_data"]
original_data = original_info["original_data"]
original_device: torch.device = original_info["device"]
if (
(device_data.device == p.data.device)
and (device_data.data_ptr() == p.data.data_ptr())
and (device_data.shape == p.data.shape)
and (device_data.dtype == p.data.dtype)
):
original_data.copy_(p.data.to(original_data.device))
p.data = original_data
elif original_device.type == "cpu":
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(
size=p.data.size(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment