Unverified Commit a62aaf1d authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

parent 603ad848
......@@ -33,11 +33,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module):
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module):
self,
config: Qwen2Config,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
use_sliding_window=use_sliding_window,
linear_method=linear_method,
quant_config=quant_config,
sliding_window=config.sliding_window)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
......@@ -233,7 +234,7 @@ class Qwen2Model(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, linear_method)
Qwen2DecoderLayer(config, layer_idx, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module):
def __init__(
self,
config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method)
self.quant_config = quant_config
self.model = Qwen2Model(config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
......
......@@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
......@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
Qwen2MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
......@@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
linear_method=None)
quant_config=None)
if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False,
)
else:
......@@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
if (config.num_experts is not None
and (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
linear_method=linear_method)
quant_config=quant_config)
else:
self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
......@@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config,
layer_idx,
linear_method=linear_method)
Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = Qwen2MoeModel(config, linear_method)
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -28,11 +28,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
......@@ -54,7 +55,7 @@ class StablelmMLP(nn.Module):
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False)
......@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
......@@ -109,11 +110,11 @@ class StablelmAttention(nn.Module):
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
linear_method=linear_method)
quant_config=quant_config)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
......@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method)
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
......@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
StablelmDecoderLayer(config, linear_method)
StablelmDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
norm_eps = getattr(config, "norm_eps",
......@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.quant_config = quant_config
self.model = StableLMEpochModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
......@@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -121,21 +122,21 @@ class Starcoder2MLP(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
......@@ -150,12 +151,11 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
linear_method=linear_method)
self.mlp = Starcoder2MLP(config, linear_method=linear_method)
self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -192,7 +192,7 @@ class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
......@@ -202,7 +202,7 @@ class Starcoder2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, linear_method=linear_method)
Starcoder2DecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
......@@ -227,10 +227,10 @@ class Starcoder2ForCausalLM(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.model = Starcoder2Model(config, linear_method=linear_method)
self.model = Starcoder2Model(config, quant_config=quant_config)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
......
......@@ -31,11 +31,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -52,17 +53,17 @@ class XverseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -85,7 +86,7 @@ class XverseAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
......@@ -112,13 +113,13 @@ class XverseAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
)
......@@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -220,7 +221,7 @@ class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
......@@ -236,7 +237,7 @@ class XverseModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
XverseDecoderLayer(config, linear_method)
XverseDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = XverseModel(config, linear_method)
self.quant_config = quant_config
self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment