Commit fac1af55 authored by EC2 Default User

updated max_new_tokens_key

parent 86fcf708
@@ -3,7 +3,7 @@ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock
class GPTJAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GPTJBlock"
max_new_tokens_key = "max_position_embeddings" # check this
max_new_tokens_key = "n_positions"
@staticmethod
def get_model_layers(model: GPTJForCausalLM):
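
Note on the key change: GPT-J's Hugging Face config class (GPTJConfig) stores the model's maximum context length under n_positions, so the lookup key used by the quantizer is updated to match. A minimal sketch of what that lookup amounts to, assuming the base class simply reads the attribute named by max_new_tokens_key off the model config (that assumption is not shown in this diff):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")  # resolves to a GPTJConfig
max_new_tokens_key = "n_positions"
context_length = getattr(config, max_new_tokens_key)  # 2048 for GPT-J-6B
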
@@ -26,7 +26,7 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM):
def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs):
layers = []
- # attention input
+ # attention input + linear 1
layers.append(dict(
prev_op=module.ln_1,
layers=[module.attn.q_proj,
@@ -37,18 +37,13 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM):
))
# attention out
- # for some reason falcon skips this too
layers.append(dict(
prev_op=module.attn.v_proj,
layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj'],
))
- # Linear 1 is included in the attention input
- # GPTJ uses a parallel Attn + MLP block so they share an input
# linear 2
- # Falcon doesn't use this - maybe we don't need this
layers.append(dict(
prev_op=module.mlp.act,
layers=[module.mlp.fc_out],
......
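
Context for the removed comments: in transformers' GPTJBlock the attention and the MLP run as parallel branches over the same ln_1 output, which is why the first scaling group feeds q_proj, k_proj, v_proj and mlp.fc_in together from ln_1 and no separate "linear 1" group is needed. A simplified paraphrase of that forward pass, condensed for illustration (kwargs, caching and dropout omitted):

def gptj_block_forward(block, hidden_states):
    # Condensed from transformers' GPTJBlock.forward; illustration only.
    residual = hidden_states
    hidden_states = block.ln_1(hidden_states)
    attn_output = block.attn(hidden_states)[0]  # attention branch consumes the ln_1 output
    mlp_output = block.mlp(hidden_states)       # MLP branch consumes the same ln_1 output (parallel block)
    return attn_output + mlp_output + residual
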