Unverified commit 6b2427f9, authored by Chen Xia, committed by GitHub
Browse files

[Quantization] Add prefix for Command A quantized model (#17017)

parent b07d7416
...@@ -89,6 +89,7 @@ class CohereMLP(nn.Module): ...@@ -89,6 +89,7 @@ class CohereMLP(nn.Module):
self, self,
config: CohereConfig, config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -99,12 +100,14 @@ class CohereMLP(nn.Module): ...@@ -99,12 +100,14 @@ class CohereMLP(nn.Module):
[self.intermediate_size] * 2, [self.intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
self.intermediate_size, self.intermediate_size,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.down_proj",
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
...@@ -158,12 +161,14 @@ class CohereAttention(nn.Module): ...@@ -158,12 +161,14 @@ class CohereAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -244,7 +249,9 @@ class CohereDecoderLayer(nn.Module): ...@@ -244,7 +249,9 @@ class CohereDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
self.mlp = CohereMLP(config, quant_config=quant_config) self.mlp = CohereMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment