Unverified commit 1faa8cb7, authored by Asaf Joseph Gardin, committed by GitHub
Browse files

[Quantization] - Added uses_meta_device_weights to quant config (#34645)


Signed-off-by: Josephasafg <ajgard7@gmail.com>
parent e89a91d9
...@@ -18,6 +18,11 @@ else: ...@@ -18,6 +18,11 @@ else:
class QuantizeMethodBase(ABC): class QuantizeMethodBase(ABC):
"""Base class for different quantized methods.""" """Base class for different quantized methods."""
# Whether this method creates weights on meta device for online quantization.
# When True, weights are created on meta device and quantized layer-wise
# in process_weights_after_loading, reducing peak memory during loading.
uses_meta_device: bool = False
@abstractmethod @abstractmethod
def create_weights( def create_weights(
self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
......
...@@ -527,6 +527,8 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod): ...@@ -527,6 +527,8 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
and quantized the weights during loading.""" and quantized the weights during loading."""
uses_meta_device: bool = True
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1039,6 +1041,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1039,6 +1041,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
quant_config: The quantization config. quant_config: The quantization config.
""" """
uses_meta_device: bool = True
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(quant_config, layer) super().__init__(quant_config, layer)
assert not quant_config.is_checkpoint_fp8_serialized assert not quant_config.is_checkpoint_fp8_serialized
......
...@@ -1092,16 +1092,20 @@ def initialize_dummy_weights( ...@@ -1092,16 +1092,20 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depends on is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type. the parameter's number of elements and its data type.
""" """
# TODO(future PR): make the check below more generic as more online
# quant backends are added # Check if any module uses online quantization with meta device weights.
is_fp8_py_quant = model_config.quantization == "fp8" # If so, we'll skip initializing params on meta device since they'll be
# handled in `process_weights_after_loading`.
def uses_meta_device(module: torch.nn.Module) -> bool:
quant_method = getattr(module, "quant_method", None)
return getattr(quant_method, "uses_meta_device", False)
has_online_quant = any(uses_meta_device(m) for m in model.modules())
for param in model.state_dict().values(): for param in model.state_dict().values():
if is_fp8_py_quant and param.device == torch.device("meta"): if has_online_quant and param.device == torch.device("meta"):
# for fp8.py's online quantization, dummy weight init will happen # For online quantization, weights are created on meta device and
# in `process_weights_after_loading`. # dummy weight init will happen in `process_weights_after_loading`.
# TODO(future PR): consider refactoring dummy model init to compose
# better with online quantization
continue continue
initialize_single_dummy_weight(param, low, high, seed) initialize_single_dummy_weight(param, low, high, seed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment