Unverified Commit e0b617d1 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Llama conversion script: adjustments for Llama Guard (#27910)

parent e3669375
......@@ -91,6 +91,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
......@@ -109,7 +110,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
tokenizer.save_pretrained(model_path)
vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
if "n_kv_heads" in params:
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment