Unverified Commit de11e654 authored by Karl Hajjar, committed by GitHub

Fix max_position_embeddings default value for llama2 to 4096 #28241 (#28754)



* Changed max_position_embeddings default value from 2048 to 4096

* force push

* Fixed formatting issues. Fixed missing argument in write_model.

* Reverted to the default value 2048 in the Llama config. Added comments for the llama_version argument.

* Fixed issue with default value of max_position_embeddings in docstring

* Updated help message for llama versions
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 2749e479
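
For context, the diff below edits the Llama weight-conversion script (judging by write_model/write_tokenizer, presumably transformers' convert_llama_weights_to_hf.py; that path is an assumption, not stated on this page). A minimal usage sketch under that assumption, with placeholder paths:

    # Convert Llama 2 weights; passing --llama_version 2 makes the fallback
    # context size 4096 instead of the Llama 1 default of 2048.
    #
    #   python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    #       --input_dir /path/to/llama2 --model_size 7B \
    #       --output_dir /path/to/hf --llama_version 2

    from transformers import LlamaConfig

    # Sanity-check the converted config (output path is hypothetical):
    config = LlamaConfig.from_pretrained("/path/to/hf")
    assert config.max_position_embeddings == 4096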
@@ -80,7 +80,9 @@ def write_json(text, path):
         json.dump(text, f)
 
 
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(
+    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
+):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
@@ -102,7 +104,16 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     if base > 10000.0:
         max_position_embeddings = 16384
     else:
-        max_position_embeddings = 2048
+        # Depending on the Llama version, the default max_position_embeddings has different values.
+        if llama_version == 1:
+            max_position_embeddings = 2048
+        elif llama_version == 2:
+            max_position_embeddings = 4096
+        else:
+            raise NotImplementedError(
+                f"Version {llama_version} of llama is not supported yet. "
+                "Current supported versions of llama are [1, 2]."
+            )
 
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
     if tokenizer_path is not None:
@@ -301,6 +312,14 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
+    parser.add_argument(
+        "--llama_version",
+        choices=[1, 2],
+        default=1,
+        type=int,
+        help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
@@ -310,6 +329,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            llama_version=args.llama_version,
         )
     else:
         write_tokenizer(args.output_dir, spm_path)
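
To exercise the new branch in isolation, here is a minimal sketch of the fallback logic added above as a standalone function; the function name and the long-context reading of base > 10000.0 are our assumptions, not from this diff:

    def default_max_position_embeddings(llama_version: int, base: float = 10000.0) -> int:
        # Mirrors the fallback added to write_model; illustrative only.
        if base > 10000.0:
            # In the script, base is read from the checkpoint; values above 10000
            # are treated as long-context variants and keep the 16384 default.
            return 16384
        if llama_version == 1:
            return 2048  # Llama 1 was trained with a 2048-token context.
        if llama_version == 2:
            return 4096  # Llama 2 was trained with a 4096-token context.
        raise NotImplementedError(
            f"Version {llama_version} of llama is not supported yet. "
            "Current supported versions of llama are [1, 2]."
        )

    assert default_max_position_embeddings(1) == 2048
    assert default_max_position_embeddings(2) == 4096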