Unverified Commit 6450c122 authored by fzyzcjy, committed by GitHub

Tiny refactor weight loading logic (#5232)

parent b6cf3532
@@ -557,12 +557,7 @@ class ModelRunner:
             return iter
 
         def model_load_weights(model, iter):
-            model.load_weights(iter)
-            for _, module in self.model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
             return model
 
         with set_default_torch_dtype(self.model_config.dtype):
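For context, here is a minimal, self-contained sketch of the pattern the runner now delegates to. The real helper is the DefaultModelLoader.load_weights_and_postprocess static method added in the next hunk; ToyModel and ToyQuantMethod below are hypothetical stand-ins, not code from the repo. The idea: load the checkpoint tensors, then give every module that carries a quant_method a chance to post-process (repack or re-quantize) its weights. Making the helper a @staticmethod lets ModelRunner reuse it without holding a loader instance.

import torch


class ToyQuantMethod:
    """Hypothetical quantization method with a post-load hook."""

    def process_weights_after_loading(self, module: torch.nn.Module) -> None:
        # Stand-in for repacking / re-quantizing the freshly loaded weights.
        module.weight.data = module.weight.data.contiguous()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # Attached as a plain attribute, the way quantized layers carry it.
        self.linear.quant_method = ToyQuantMethod()

    def load_weights(self, weights):
        params = dict(self.named_parameters())
        for name, tensor in weights:
            params[name].data.copy_(tensor)


def load_weights_and_postprocess_sketch(model, weights):
    # Step 1: stream checkpoint tensors into the model.
    model.load_weights(weights)
    # Step 2: let each quantized module post-process its weights in place.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)


model = ToyModel()
load_weights_and_postprocess_sketch(
    model,
    [("linear.weight", torch.zeros(4, 4)), ("linear.bias", torch.zeros(4))],
)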
@@ -374,20 +374,27 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-            model.load_weights(self._get_all_weights(model_config, model))
-
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
+
         return model.eval()
 
+    @staticmethod
+    def load_weights_and_postprocess(model, weights, target_device):
+        model.load_weights(weights)
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
 
 class LayeredModelLoader(DefaultModelLoader):
     """Model loader that loads weights layer by layer so that one can quantize a
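The comment block in the new helper refers to device_loading_context, whose implementation is not part of this diff. As a rough sketch only, assuming all it has to do is shuttle a module's parameters between their offload location and the target device (the name device_loading_context_sketch and the per-parameter moves are illustrative, not the project's actual code):

from contextlib import contextmanager

import torch


@contextmanager
def device_loading_context_sketch(module: torch.nn.Module, target_device: torch.device):
    # CPU target: nothing is offloaded anywhere else, so just run the body.
    if target_device.type == "cpu":
        yield module
        return

    # Remember where each parameter currently lives (CPU if offloaded).
    original_devices = {
        name: p.device for name, p in module.named_parameters(recurse=False)
    }
    try:
        # Move offloaded parameters onto the target device for processing.
        for _, p in module.named_parameters(recurse=False):
            if p.device.type == "cpu":
                p.data = p.data.to(target_device)
        yield module
    finally:
        # Move parameters back to wherever they lived before.
        for name, p in module.named_parameters(recurse=False):
            p.data = p.data.to(original_devices[name])

The real context manager presumably handles more corner cases (for example, parameters that process_weights_after_loading replaces), but the overall shape is the same: move the weights on, run the post-processing inside the with block, move them back off.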