Commit 4e0b233d authored by zhuwenwen
Browse files

update lm_head weight to support llama3.2

parent aba40fda
@@ -30,7 +30,7 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0'
else:
os.environ['LLAMA_NN'] = '1'
-if architectures == ['BloomForCausalLM']:
+if architectures == ['BloomForCausalLM'] or architectures == ['LlamaForCausalLM']:
os.environ['LM_TN'] = '1'
else:
os.environ['LM_TN'] = '0'
@@ -50,7 +50,7 @@ def get_model_architecture(
os.environ['AWQ_PAD'] = '0'
else:
os.environ['LLAMA_NN'] = '0'
-os.environ['LM_TN'] = '0'
+os.environ['LM_TN'] = '1'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
os.environ['AWQ_PAD'] = '0'
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment