Unverified Commit 2660b928 authored by sfbemerk's avatar sfbemerk Committed by GitHub
Browse files

Bugfix for offloading+prefetch for GLM-4.7-FP8 (#37178)


Signed-off-by: default avatarBenjamin Merkel <benjamin.merkel@tngtech.com>
Co-authored-by: default avatarBenjamin Merkel <benjamin.merkel@tngtech.com>
parent 293f036e
...@@ -431,10 +431,32 @@ class _ModuleOffloader: ...@@ -431,10 +431,32 @@ class _ModuleOffloader:
Called after process_weights_after_loading to ensure _cpu_storage Called after process_weights_after_loading to ensure _cpu_storage
contains the final processed weights, not stale pre-loading data. contains the final processed weights, not stale pre-loading data.
Parameters whose underlying nn.Parameter was deleted by
process_weights_after_loading (e.g. transient KV-cache scale params)
are pruned from self._param_offloaders so they do not participate in
buffer-pool allocation or prefetching.
""" """
for param_offloader in self._param_offloaders.values(): for param_offloader in self._param_offloaders.values():
param_offloader.sync_cpu_storage() param_offloader.sync_cpu_storage()
# Remove offloaders whose parameter was deleted during
# process_weights_after_loading (e.g. k_scale / v_scale).
deleted = [
name
for name, offloader in self._param_offloaders.items()
if getattr(offloader, "_param_deleted", False)
]
if deleted:
logger.debug(
"Pruning %d transient offloaded param(s) that were deleted "
"by process_weights_after_loading: %s",
len(deleted),
deleted,
)
for name in deleted:
del self._param_offloaders[name]
def get_param_infos(self) -> list[ParamInfo]: def get_param_infos(self) -> list[ParamInfo]:
"""Get parameter metadata for buffer pool allocation. """Get parameter metadata for buffer pool allocation.
...@@ -590,6 +612,11 @@ class _CpuParamOffloader(_BaseParamOffloader): ...@@ -590,6 +612,11 @@ class _CpuParamOffloader(_BaseParamOffloader):
super().__init__(module, param_name) super().__init__(module, param_name)
self._cpu_storage: torch.Tensor | None = None self._cpu_storage: torch.Tensor | None = None
self._gpu_buffer: torch.Tensor | None = None # Store reference to GPU buffer self._gpu_buffer: torch.Tensor | None = None # Store reference to GPU buffer
# Set to True if the underlying nn.Parameter was deleted by
# process_weights_after_loading (e.g. transient KV-cache scale params
# such as k_scale/v_scale created by BaseKVCacheMethod.create_weights
# and deleted after copying into permanent _k_scale buffers).
self._param_deleted: bool = False
# Offload to CPU immediately to free GPU memory during model loading # Offload to CPU immediately to free GPU memory during model loading
self._offload_to_cpu_internal() self._offload_to_cpu_internal()
...@@ -696,8 +723,22 @@ class _CpuParamOffloader(_BaseParamOffloader): ...@@ -696,8 +723,22 @@ class _CpuParamOffloader(_BaseParamOffloader):
1. process_weights_after_loading may transform weights (quantization) 1. process_weights_after_loading may transform weights (quantization)
2. device_loading_context creates NEW CPU tensors when moving back 2. device_loading_context creates NEW CPU tensors when moving back
3. Our old _cpu_storage would have pre-processed or stale data 3. Our old _cpu_storage would have pre-processed or stale data
If the parameter no longer exists on the module (e.g. transient
KV-cache scale parameters such as k_scale/v_scale that are created
by BaseKVCacheMethod.create_weights() and then deleted by
process_weights_after_loading() after copying their values into
permanent _k_scale buffers), the offloader marks itself as deleted
and skips the sync. The caller (_ModuleOffloader.sync_cpu_storage)
is responsible for removing these stale entries.
""" """
self._update_cpu_storage_from_param() try:
self._update_cpu_storage_from_param()
except AttributeError:
# The parameter was deleted by process_weights_after_loading.
# Drop the now-stale CPU storage so this offloader can be pruned.
self._param_deleted = True
self._cpu_storage = None
def post_init(self): def post_init(self):
"""No-op: offloading done in offload_to_cpu/assign_static_buffer.""" """No-op: offloading done in offload_to_cpu/assign_static_buffer."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment