Unverified Commit 30b404ce authored by Jerry Zhang's avatar Jerry Zhang Committed by GitHub
Browse files

Add torchao quant for mixtral and qwen_moe (#1418)

parent 70b68029
...@@ -2,10 +2,20 @@ ...@@ -2,10 +2,20 @@
Common utilities for torchao. Common utilities for torchao.
""" """
from typing import Dict, Set
import torch import torch
def torchao_quantize_param_data(param, torchao_config): def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
"""Quantize a Tensor with torchao quantization specified by torchao_config
Args:
`param`: weight parameter of the linear module
`torchao_config`: type of quantization and their arguments we want to use to
quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
128
"""
# Lazy import to suppress some warnings # Lazy import to suppress some warnings
from torchao.quantization import ( from torchao.quantization import (
int4_weight_only, int4_weight_only,
...@@ -36,3 +46,30 @@ def torchao_quantize_param_data(param, torchao_config): ...@@ -36,3 +46,30 @@ def torchao_quantize_param_data(param, torchao_config):
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
quantize_(dummy_linear, float8_weight_only()) quantize_(dummy_linear, float8_weight_only())
return dummy_linear.weight return dummy_linear.weight
def apply_torchao_config_(
self: torch.nn.Module,
params_dict: Dict[str, torch.Tensor],
param_suffixes: Set[str],
) -> None:
"""A util function used for quantizing the weight parameters after they are loaded if
self.torchao_config is specified
Args:
`self`: the model we want to quantize
`params_dict`: dictionary mapping from param_name to the parameter Tensor
`param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
Returns:
None, the `params_dict` is modified inplace and the weights of `self` model are quantized
"""
if self.torchao_config:
for param_suffix in param_suffixes:
for name in params_dict:
param = params_dict[name]
if param_suffix in name and param.ndim == 2:
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
self.load_state_dict(params_dict, assign=True)
...@@ -41,7 +41,7 @@ from sglang.srt.layers.activation import SiluAndMul ...@@ -41,7 +41,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -405,24 +405,7 @@ class LlamaForCausalLM(nn.Module): ...@@ -405,24 +405,7 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.torchao_config: apply_torchao_config_(self, params_dict, set(["proj.weight"]))
if name.endswith("proj.weight") and param.ndim == 2:
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
if self.torchao_config:
# quantizing the loaded, stacked params, e.g. "...qkv_proj"
stacked_params = set(entry[0] for entry in stacked_params_mapping)
for param_suffix in stacked_params:
for name in params_dict:
if param_suffix in name:
param = params_dict[name]
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
self.load_state_dict(params_dict, assign=True)
class Phi3ForCausalLM(LlamaForCausalLM): class Phi3ForCausalLM(LlamaForCausalLM):
......
...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -296,6 +298,7 @@ class MixtralForCausalLM(nn.Module): ...@@ -296,6 +298,7 @@ class MixtralForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
...@@ -376,5 +379,7 @@ class MixtralForCausalLM(nn.Module): ...@@ -376,5 +379,7 @@ class MixtralForCausalLM(nn.Module):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
EntryClass = MixtralForCausalLM EntryClass = MixtralForCausalLM
...@@ -47,6 +47,8 @@ from sglang.srt.layers.activation import SiluAndMul ...@@ -47,6 +47,8 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -359,6 +361,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -359,6 +361,7 @@ class Qwen2MoeForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Qwen2MoeModel(config, cache_config, quant_config) self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size, config.hidden_size, quant_config=quant_config
...@@ -451,5 +454,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -451,5 +454,7 @@ class Qwen2MoeForCausalLM(nn.Module):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
EntryClass = Qwen2MoeForCausalLM EntryClass = Qwen2MoeForCausalLM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment