Unverified commit 399c7986, authored by Michael Goin, committed by GitHub
Browse files

Remove ScaledActivation for AWQ (#10057)


Signed-off-by: mgoin <michael@neuralmagic.com>
parent 406d4cc4
...@@ -50,9 +50,6 @@ class Int8TpuConfig(QuantizationConfig): ...@@ -50,9 +50,6 @@ class Int8TpuConfig(QuantizationConfig):
return TPUInt8LinearMethod(self) return TPUInt8LinearMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]:
return []
class TPUInt8LinearMethod(LinearMethodBase): class TPUInt8LinearMethod(LinearMethodBase):
"""Int8 Linear method for TPU Quant. """ """Int8 Linear method for TPU Quant. """
......
...@@ -393,8 +393,7 @@ class BartEncoderLayer(nn.Module): ...@@ -393,8 +393,7 @@ class BartEncoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function, self.activation_fn = get_act_fn(config.activation_function)
quant_config)
ffn_hidden_size = self.embed_dim ffn_hidden_size = self.embed_dim
ffn_intermediate_size = config.encoder_ffn_dim ffn_intermediate_size = config.encoder_ffn_dim
...@@ -405,7 +404,7 @@ class BartEncoderLayer(nn.Module): ...@@ -405,7 +404,7 @@ class BartEncoderLayer(nn.Module):
bias=ffn_has_bias, bias=ffn_has_bias,
quant_config=quant_config, quant_config=quant_config,
) )
self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size) self.act = get_act_fn("gelu")
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
ffn_intermediate_size, ffn_intermediate_size,
ffn_hidden_size, ffn_hidden_size,
...@@ -473,8 +472,7 @@ class BartDecoderLayer(nn.Module): ...@@ -473,8 +472,7 @@ class BartDecoderLayer(nn.Module):
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.activation_fn = get_act_fn(config.activation_function, self.activation_fn = get_act_fn(config.activation_function)
quant_config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
''' '''
......
...@@ -146,7 +146,7 @@ class BloomMLP(nn.Module): ...@@ -146,7 +146,7 @@ class BloomMLP(nn.Module):
4 * hidden_size, 4 * hidden_size,
quant_config=quant_config, quant_config=quant_config,
) )
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.gelu_impl = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size, 4 * hidden_size,
hidden_size, hidden_size,
......
...@@ -212,7 +212,7 @@ class FalconMLP(nn.Module): ...@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
bias=config.bias, bias=config.bias,
skip_bias_add=True, skip_bias_add=True,
quant_config=quant_config) quant_config=quant_config)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.act = get_act_fn("gelu")
self.reduce_row_parallel_results = not (config.new_decoder_architecture self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn) or config.parallel_attn)
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
......
...@@ -123,8 +123,7 @@ class GPT2MLP(nn.Module): ...@@ -123,8 +123,7 @@ class GPT2MLP(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj", prefix=f"{prefix}.c_proj",
) )
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.activation_function)
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states) hidden_states, _ = self.c_fc(hidden_states)
......
...@@ -135,8 +135,7 @@ class GPTBigMLP(nn.Module): ...@@ -135,8 +135,7 @@ class GPTBigMLP(nn.Module):
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
) )
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.activation_function)
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states) hidden_states, _ = self.c_fc(hidden_states)
......
...@@ -130,8 +130,7 @@ class GPTJMLP(nn.Module): ...@@ -130,8 +130,7 @@ class GPTJMLP(nn.Module):
hidden_size, hidden_size,
quant_config=quant_config, quant_config=quant_config,
) )
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.activation_function)
intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc_in(hidden_states) hidden_states, _ = self.fc_in(hidden_states)
......
...@@ -128,8 +128,7 @@ class GPTNeoXMLP(nn.Module): ...@@ -128,8 +128,7 @@ class GPTNeoXMLP(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
) )
self.act = get_act_fn(config.hidden_act, quant_config, self.act = get_act_fn(config.hidden_act)
config.intermediate_size)
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states, _ = self.dense_h_to_4h(hidden_states) hidden_states, _ = self.dense_h_to_4h(hidden_states)
......
...@@ -153,7 +153,7 @@ class MPTMLP(nn.Module): ...@@ -153,7 +153,7 @@ class MPTMLP(nn.Module):
bias=not config.no_bias, bias=not config.no_bias,
quant_config=quant_config, quant_config=quant_config,
) )
self.act = get_act_fn("gelu", quant_config, intermediate_size) self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
......
...@@ -147,8 +147,7 @@ class OPTDecoderLayer(nn.Module): ...@@ -147,8 +147,7 @@ class OPTDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1", prefix=f"{prefix}.fc1",
) )
self.activation_fn = get_act_fn(config.activation_function, self.activation_fn = get_act_fn(config.activation_function)
quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.ffn_dim, config.ffn_dim,
self.embed_dim, self.embed_dim,
......
...@@ -60,7 +60,7 @@ class PersimmonMLP(nn.Module): ...@@ -60,7 +60,7 @@ class PersimmonMLP(nn.Module):
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
self.act = get_act_fn(config.hidden_act, quant_config) self.act = get_act_fn(config.hidden_act)
def forward(self, hidden_states) -> torch.Tensor: def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.dense_h_to_4h(hidden_states) hidden_states, _ = self.dense_h_to_4h(hidden_states)
......
...@@ -152,7 +152,7 @@ class PhiMLP(nn.Module): ...@@ -152,7 +152,7 @@ class PhiMLP(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
) )
self.act = get_act_fn(config.hidden_act, quant_config, n_inner) self.act = get_act_fn(config.hidden_act)
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states) hidden_states, _ = self.fc1(hidden_states)
......
...@@ -203,7 +203,7 @@ class QwenVMLP(nn.Module): ...@@ -203,7 +203,7 @@ class QwenVMLP(nn.Module):
intermediate_size, intermediate_size,
bias=True, bias=True,
quant_config=quant_config) quant_config=quant_config)
self.act_fn = get_act_fn("gelu", quant_config, intermediate_size) self.act_fn = get_act_fn("gelu")
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
intermediate_size, intermediate_size,
hidden_size, hidden_size,
......
...@@ -139,8 +139,7 @@ class Starcoder2MLP(nn.Module): ...@@ -139,8 +139,7 @@ class Starcoder2MLP(nn.Module):
bias=config.use_bias, bias=config.use_bias,
quant_config=quant_config, quant_config=quant_config,
) )
self.act = get_act_fn(config.hidden_act, quant_config, self.act = get_act_fn(config.hidden_act)
config.intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states) hidden_states, _ = self.c_fc(hidden_states)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment