Unverified Commit 862bcff8 authored by Ke Wen, committed by GitHub

Support loading of larger models with on-the-fly quantization (#3061)

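This change adds a `layered` load format that materializes and quantizes the model one submodule at a time, so the full-precision weights never have to be resident on the device all at once. A minimal usage sketch, assuming the existing `torchao_config` server argument and a placeholder model path (only `load_format="layered"` is new here):

```python
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    load_format="layered",    # new: load and quantize layer by layer
    torchao_config="int8wo",  # on-the-fly int8 weight-only quantization via torchao
)
```

The equivalent CLI flags are `--load-format layered` together with the existing `--torchao-config` option.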
parent 8b84e69f
......@@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
LAYERED = "layered"
@dataclass
......
......@@ -5,6 +5,7 @@ Common utilities for torchao.
import logging
import os
import pwd
from typing import Callable, Optional
import torch
......@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
return True
def proj_filter(
module: torch.nn.Module,
fqn: str,
):
"""Filter function for quantizing projection layers."""
return "proj" in fqn
def apply_torchao_config_to_model(
model: torch.nn.Module, torchao_config: str, filter_fn=None
model: torch.nn.Module,
torchao_config: str,
filter_fn: Optional[Callable] = proj_filter,
):
"""Quantize a modelwith torchao quantization specified by torchao_config
......@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
)
from torchao.quantization.observer import PerRow, PerTensor
if filter_fn is None:
def filter_fn(module, fqn):
return "proj" in fqn
if torchao_config == "" or torchao_config is None:
return model
elif "int8wo" in torchao_config:
......
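The default `proj_filter` restricts quantization to submodules whose fully qualified name contains `proj` (the attention and MLP projection layers). A minimal, self-contained sketch of how such a filter is consumed by torchao's `quantize_` API, which is roughly what `apply_torchao_config_to_model` does for the `int8wo` case; the `TinyBlock` module here is hypothetical:

```python
import torch
from torchao.quantization import int8_weight_only, quantize_

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(64, 64)   # name contains "proj" -> quantized
        self.lm_head = torch.nn.Linear(64, 64)  # no match -> left in full precision

def proj_filter(module: torch.nn.Module, fqn: str) -> bool:
    """Quantize only projection layers, mirroring the filter added above."""
    return "proj" in fqn

model = TinyBlock()
quantize_(model, int8_weight_only(), filter_fn=proj_filter)
```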
......@@ -185,9 +185,12 @@ class ModelRunner:
self.load_model()
# Apply torchao quantization
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may already have been applied during weight loading
if not torchao_applied:
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
# Apply torch TP if the model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
......
......@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
return model.eval()
class LayeredModelLoader(DefaultModelLoader):
"""Model loader that loads weights layer by layer so that one can quantize a
layer before loading another to make the peak memory envelope smaller."""
def __init__(self, load_config: LoadConfig):
# Back to the default load format
load_config.load_format = LoadFormat.AUTO
super().__init__(load_config)
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.managers.schedule_batch import global_server_args_dict
torchao_config = global_server_args_dict.get("torchao_config")
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
# Create model on meta device
with torch.device("meta"):
model = _initialize_model(
model_config,
self.load_config,
)
# Check model's layered load support
if not hasattr(model, "load_weights_to_module"):
raise ValueError(
"LayeredModelLoader requires the model to have a "
"`load_weights_to_module` method. "
f"{model_config.model_path} does not support it."
)
# Get all weights from disk
weights = self._get_all_weights(model_config, model)
# Helper function to recursively fill the weights of a module
def fill_module(module, fqn: List[str], weights):
"""
fqn: list of strings representing the fully qualified name of `module`.
"""
# Layer by layer
for name, submod in module.named_children():
fill_module(submod, fqn + [name], weights)
# First materialize on target device
module.to_empty(device=target_device, recurse=False)
fqn_path = ".".join(fqn)
# Fill weights
model.load_weights_to_module(
fqn_path,
weights,
)
# Quantize weights if applicable
if torchao_config and "proj" in fqn_path:
# Note: `None` here is needed to indicate no filter, see
# `apply_torchao_config_to_model` for details.
apply_torchao_config_to_model(module, torchao_config, None)
# Start calling on root module
fill_module(model, [], weights)
if torchao_config:
model.torchao_applied = True
return model.eval()
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
......@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)
if load_config.load_format == LoadFormat.LAYERED:
return LayeredModelLoader(load_config)
return DefaultModelLoader(load_config)
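`LayeredModelLoader` builds the model on the meta device and then walks it bottom-up: each submodule is materialized with `to_empty`, filled from disk via `load_weights_to_module`, and quantized before the traversal moves on, so at most one layer's full-precision weights are resident at a time. A standalone sketch of that pattern, using plain `nn.Linear` layers and torchao's `int8_weight_only` as stand-ins (the actual loader drives the model's own `load_weights_to_module`):

```python
import torch
from torchao.quantization import int8_weight_only, quantize_

with torch.device("meta"):
    # Parameters are created without allocating real storage.
    model = torch.nn.Sequential(*(torch.nn.Linear(4096, 4096) for _ in range(8)))

for name, layer in model.named_children():
    # Materialize only this layer's parameters on the target device ...
    layer.to_empty(device="cuda", recurse=False)
    # ... fill them with the real weights here (elided; the loader uses
    # load_weights_to_module for this step) ...
    # ... and quantize before the next layer is materialized, so the
    # full-precision copy of at most one layer exists at any point.
    quantize_(layer, int8_weight_only())
```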
......@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
return len(params_dict)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights_to_module(
self,
fqn: str,
weights: Iterable[Tuple[str, torch.Tensor]],
):
"""Load weights onto submodule pointed by path `fqn`."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
module = self.get_submodule(fqn)
params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
......@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if name.endswith(".bias") or name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
......@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if name.endswith(".bias") or name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
):
"""Load weights onto the full model."""
self.load_weights_to_module("", weights)
class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
pass
......
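`load_weights_to_module` reuses the existing stacked-parameter mapping but only considers parameters owned directly by the submodule at `fqn` (`recurse=False`), which is what lets the layered loader fill one module at a time. A hedged sketch of driving it by hand, given a constructed `TorchNativeLlamaForCausalLM` instance `model`; the checkpoint iterator and the exact submodule path are illustrative:

```python
# `weights` can be any Iterable[Tuple[str, torch.Tensor]] over checkpoint entries.
weights = list(iterate_checkpoint_weights())  # hypothetical helper

# Fill only the parameters that live directly on this submodule; entries for
# other modules in the iterator are simply skipped for this call.
model.load_weights_to_module("model.layers.0.self_attn.qkv_proj", weights)

# The original load_weights API is kept as a thin wrapper over the root module:
model.load_weights(weights)
```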
......@@ -317,6 +317,7 @@ class ServerArgs:
"dummy",
"gguf",
"bitsandbytes",
"layered",
],
help="The format of the model weights to load. "
'"auto" will try to load the weights in the safetensors format '
......@@ -330,7 +331,10 @@ class ServerArgs:
"which is mainly for profiling."
'"gguf" will load the weights in the gguf format. '
'"bitsandbytes" will load the weights using bitsandbytes '
"quantization.",
"quantization."
'"layered" loads weights layer by layer so that one can quantize a '
"layer before loading another to make the peak memory envelope "
"smaller.",
)
parser.add_argument(
"--trust-remote-code",
......
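A short sketch of how the new choice flows through the CLI into `ServerArgs`, assuming `ServerArgs` keeps its usual `add_cli_args`/`from_cli_args` helpers and that the `--torchao-config` flag already exists:

```python
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    [
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--load-format", "layered",
        "--torchao-config", "int8wo",
    ]
)
server_args = ServerArgs.from_cli_args(args)
```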