Commit fac1af55 authored by EC2 Default User

updated max_new_tokens_key

parent 86fcf708
@@ -3,7 +3,7 @@ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock
 class GPTJAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "GPTJBlock"
-    max_new_tokens_key = "max_position_embeddings" # check this
+    max_new_tokens_key = "n_positions"
     @staticmethod
     def get_model_layers(model: GPTJForCausalLM):
@@ -26,7 +26,7 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM):
     def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs):
         layers = []
-        # attention input
+        # attention input + linear 1
         layers.append(dict(
             prev_op=module.ln_1,
             layers=[module.attn.q_proj,
@@ -37,18 +37,13 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM):
         ))
         # attention out
-        # for some reason falcon skips this too
         layers.append(dict(
             prev_op=module.attn.v_proj,
             layers=[module.attn.out_proj],
             inp=input_feat['attn.out_proj'],
         ))
-        # Linear 1 is included in the attention input
-        # GPTJ uses a parallel Attn + MLP block so they share an input
         # linear 2
-        # Falcon doesn't use this - maybe we don't need this
         layers.append(dict(
             prev_op=module.mlp.act,
             layers=[module.mlp.fc_out],
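For reference on the key change: GPT-J's Hugging Face config stores the context length under n_positions rather than max_position_embeddings, so n_positions is the native field name for this lookup (attribute access via max_position_embeddings may still resolve through transformers' alias map, but the stored field is n_positions). A minimal sanity check, assuming the transformers package is installed; this snippet is illustrative and not part of the changed file:

    # Hedged check: GPT-J's serialized config carries "n_positions";
    # "max_position_embeddings" is not a stored field of GPTJConfig.
    from transformers import GPTJConfig

    cfg = GPTJConfig()
    print("n_positions" in cfg.to_dict())              # True
    print("max_position_embeddings" in cfg.to_dict())  # False
    print(cfg.n_positions)                             # 2048 by default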