Unverified Commit de0526f6 authored by kewang-xlnx's avatar kewang-xlnx Committed by GitHub
Browse files

[Misc][Quark] Upstream Quark format to VLLM (#10765)


Signed-off-by: default avatarkewang-xlnx <kewang@xilinx.com>
Signed-off-by: default avatarkewang2 <kewang2@amd.com>
Co-authored-by: default avatarkewang2 <kewang2@amd.com>
Co-authored-by: default avatarMichael Goin <michael@neuralmagic.com>
parent 5ecf3e0a
...@@ -31,8 +31,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -31,8 +31,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -254,6 +252,7 @@ class Gemma2Model(nn.Module): ...@@ -254,6 +252,7 @@ class Gemma2Model(nn.Module):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
...@@ -329,7 +328,8 @@ class Gemma2Model(nn.Module): ...@@ -329,7 +328,8 @@ class Gemma2Model(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if scale_name := get_compressed_tensors_cache_scale(name): if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for compressed-tensors quantization # Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name] param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
......
...@@ -313,6 +313,20 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -313,6 +313,20 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name: if "attn.bias" in name or "attn.masked_bias" in name:
continue continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -371,6 +369,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -371,6 +369,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config
self.model = GraniteModel(vllm_config=vllm_config, self.model = GraniteModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
...@@ -474,12 +473,15 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -474,12 +473,15 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# processed with quantization, LoRA, fine-tuning, etc. # processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name: if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue continue
if scale_name := get_compressed_tensors_cache_scale(name): if (self.quant_config is not None and
# Loading kv cache scales for compressed-tensors quantization (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name] param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
loaded_weight = loaded_weight[0] loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(scale_name) loaded_params.add(scale_name)
continue continue
......
...@@ -38,8 +38,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -38,8 +38,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -306,6 +304,7 @@ class LlamaModel(nn.Module): ...@@ -306,6 +304,7 @@ class LlamaModel(nn.Module):
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -396,12 +395,15 @@ class LlamaModel(nn.Module): ...@@ -396,12 +395,15 @@ class LlamaModel(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if scale_name := get_compressed_tensors_cache_scale(name): if (self.quant_config is not None and
# Loading kv cache scales for compressed-tensors quantization (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name] param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
loaded_weight = loaded_weight[0] loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(scale_name) loaded_params.add(scale_name)
continue continue
......
...@@ -347,6 +347,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -347,6 +347,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config, self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
...@@ -428,6 +429,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -428,6 +429,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -1116,6 +1116,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1116,6 +1116,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles self.max_num_tiles = config.vision_config.max_num_tiles
...@@ -1429,6 +1430,18 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1429,6 +1430,18 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
name = name.replace('patch_embedding.weight', name = name.replace('patch_embedding.weight',
'patch_embedding._linear.weight') 'patch_embedding._linear.weight')
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
updated_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -405,6 +405,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -405,6 +405,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config
self.model = NemotronModel(vllm_config=vllm_config, self.model = NemotronModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
...@@ -489,6 +490,18 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -489,6 +490,18 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -546,6 +546,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -546,6 +546,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = vllm_config.quant_config
self.model = PhiMoEModel(vllm_config=vllm_config, self.model = PhiMoEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
...@@ -623,6 +624,19 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -623,6 +624,19 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -279,6 +279,7 @@ class Qwen2Model(nn.Module): ...@@ -279,6 +279,7 @@ class Qwen2Model(nn.Module):
)) ))
self.config = config self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -364,6 +365,18 @@ class Qwen2Model(nn.Module): ...@@ -364,6 +365,18 @@ class Qwen2Model(nn.Module):
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -409,6 +407,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -409,6 +407,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config
self.model = SolarModel( self.model = SolarModel(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -491,12 +490,15 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -491,12 +490,15 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if scale_name := get_compressed_tensors_cache_scale(name): if (self.quant_config is not None and
# Loading kv cache scales for compressed-tensors quantization (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name] param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
loaded_weight = loaded_weight[0] loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(scale_name) loaded_params.add(scale_name)
continue continue
......
...@@ -56,8 +56,14 @@ class BasevLLMParameter(Parameter): ...@@ -56,8 +56,14 @@ class BasevLLMParameter(Parameter):
def weight_loader(self): def weight_loader(self):
return self._weight_loader return self._weight_loader
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
cond1 = self.data.ndim == 1 and self.data.numel() == 1
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
return (cond1 and cond2)
def _assert_and_load(self, loaded_weight: torch.Tensor): def _assert_and_load(self, loaded_weight: torch.Tensor):
assert self.data.shape == loaded_weight.shape assert (self.data.shape == loaded_weight.shape
or self._is_1d_and_scalar(loaded_weight))
self.data.copy_(loaded_weight) self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor): def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
......
...@@ -70,7 +70,7 @@ class RocmPlatform(Platform): ...@@ -70,7 +70,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf" "fbgemm_fp8", "gguf", "quark"
] ]
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment