Unverified Commit a62aaf1d authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

parent 603ad848
...@@ -33,11 +33,12 @@ from vllm.config import LoRAConfig ...@@ -33,11 +33,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module): ...@@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module): ...@@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module):
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
use_sliding_window: bool = False, use_sliding_window: bool = False,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None: sliding_window: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module): ...@@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module):
self, self,
config: Qwen2Config, config: Qwen2Config,
layer_idx: int, layer_idx: int,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
use_sliding_window=use_sliding_window, use_sliding_window=use_sliding_window,
linear_method=linear_method, quant_config=quant_config,
sliding_window=config.sliding_window) sliding_window=config.sliding_window)
self.mlp = Qwen2MLP( self.mlp = Qwen2MLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module): ...@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -233,7 +234,7 @@ class Qwen2Model(nn.Module): ...@@ -233,7 +234,7 @@ class Qwen2Model(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, linear_method) Qwen2DecoderLayer(config, layer_idx, quant_config)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
del lora_config del lora_config
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = Qwen2Model(config, linear_method) self.model = Qwen2Model(config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight self.lm_head_weight = self.model.embed_tokens.weight
......
...@@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module): ...@@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True, reduce_results: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
reduce_results=reduce_results) reduce_results=reduce_results)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
...@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
Qwen2MoeMLP(hidden_size=config.hidden_size, Qwen2MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
reduce_results=False) reduce_results=False)
for idx in range(self.n_routed_experts) for idx in range(self.n_routed_experts)
]) ])
...@@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts, self.n_routed_experts,
bias=False, bias=False,
linear_method=None) quant_config=None)
if config.shared_expert_intermediate_size > 0: if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP( self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size, intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
reduce_results=False, reduce_results=False,
) )
else: else:
...@@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module): ...@@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module): ...@@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
layer_idx: int, layer_idx: int,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
if (config.num_experts is not None if (config.num_experts is not None
and (layer_idx + 1) % config.decoder_sparse_step == 0): and (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config, self.mlp = Qwen2MoeSparseMoeBlock(config=config,
linear_method=linear_method) quant_config=quant_config)
else: else:
self.mlp = Qwen2MoeMLP( self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config, Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
layer_idx,
linear_method=linear_method)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = Qwen2MoeModel(config, linear_method) self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -28,11 +28,12 @@ from transformers import PretrainedConfig ...@@ -28,11 +28,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module): ...@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -54,7 +55,7 @@ class StablelmMLP(nn.Module): ...@@ -54,7 +55,7 @@ class StablelmMLP(nn.Module):
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2, config.hidden_size, [config.intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(config.intermediate_size, self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=False) bias=False)
...@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module): ...@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -109,11 +110,11 @@ class StablelmAttention(nn.Module): ...@@ -109,11 +110,11 @@ class StablelmAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_key_value_heads, self.total_num_key_value_heads,
self.qkv_bias, self.qkv_bias,
linear_method=linear_method) quant_config=quant_config)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_ndims, rotary_dim=self.rotary_ndims,
...@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module): ...@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = StablelmAttention(config) self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method) self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05)) getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
...@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module): ...@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
StablelmDecoderLayer(config, linear_method) StablelmDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
...@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module): ...@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = StableLMEpochModel(config, linear_method) self.model = StableLMEpochModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module): ...@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module): ...@@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=self.use_bias, bias=self.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=self.use_bias, bias=self.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -121,21 +122,21 @@ class Starcoder2MLP(nn.Module): ...@@ -121,21 +122,21 @@ class Starcoder2MLP(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.c_fc = ColumnParallelLinear( self.c_fc = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=config.use_bias, bias=config.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=config.use_bias, bias=config.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None) quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config, self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size) config.intermediate_size)
...@@ -150,12 +151,11 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -150,12 +151,11 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config, self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
linear_method=linear_method) self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.mlp = Starcoder2MLP(config, linear_method=linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon) eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
...@@ -192,7 +192,7 @@ class Starcoder2Model(nn.Module): ...@@ -192,7 +192,7 @@ class Starcoder2Model(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -202,7 +202,7 @@ class Starcoder2Model(nn.Module): ...@@ -202,7 +202,7 @@ class Starcoder2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, linear_method=linear_method) Starcoder2DecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
...@@ -227,10 +227,10 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -227,10 +227,10 @@ class Starcoder2ForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = Starcoder2Model(config, linear_method=linear_method) self.model = Starcoder2Model(config, quant_config=quant_config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings: if config.tie_word_embeddings:
......
...@@ -31,11 +31,12 @@ from vllm.config import LoRAConfig ...@@ -31,11 +31,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -52,17 +53,17 @@ class XverseMLP(nn.Module): ...@@ -52,17 +53,17 @@ class XverseMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -85,7 +86,7 @@ class XverseAttention(nn.Module): ...@@ -85,7 +86,7 @@ class XverseAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
...@@ -112,13 +113,13 @@ class XverseAttention(nn.Module): ...@@ -112,13 +113,13 @@ class XverseAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
bias=getattr(config, "bias", False), bias=getattr(config, "bias", False),
sliding_window=sliding_window, sliding_window=sliding_window,
) )
...@@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -220,7 +221,7 @@ class XverseModel(nn.Module): ...@@ -220,7 +221,7 @@ class XverseModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -236,7 +237,7 @@ class XverseModel(nn.Module): ...@@ -236,7 +237,7 @@ class XverseModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
XverseDecoderLayer(config, linear_method) XverseDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module): ...@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config=None, lora_config=None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = XverseModel(config, linear_method) self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment