Commit 581d366d authored by chenych's avatar chenych
Browse files

Support GLM-4/GLM-4-0414/GLM-Z1

parent 428c5813
......@@ -19,12 +19,12 @@ import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"template": "llama3",
"infer_dtype": "float16",
}
......@@ -37,7 +37,7 @@ def fix_valuehead_cpu_loading():
def test_base():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA)
ref_model = load_reference_model(TINY_LLAMA3)
compare_model(model, ref_model)
......
......@@ -19,10 +19,10 @@ import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "freeze",
......@@ -36,7 +36,7 @@ TRAIN_ARGS = {
}
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"finetuning_type": "freeze",
"template": "llama3",
"infer_dtype": "float16",
......
......@@ -19,10 +19,10 @@ import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
......@@ -36,7 +36,7 @@ TRAIN_ARGS = {
}
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"finetuning_type": "full",
"template": "llama3",
"infer_dtype": "float16",
......
......@@ -27,14 +27,14 @@ from llamafactory.train.test_utils import (
)
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
......@@ -48,7 +48,7 @@ TRAIN_ARGS = {
}
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"adapter_name_or_path": TINY_LLAMA_ADAPTER,
"finetuning_type": "lora",
"template": "llama3",
......@@ -81,13 +81,13 @@ def test_lora_train_extra_modules():
def test_lora_train_old_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(model, ref_model)
def test_lora_train_new_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(
model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
)
......@@ -105,5 +105,5 @@ def test_lora_train_valuehead():
def test_lora_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
compare_model(model, ref_model)
......@@ -19,12 +19,12 @@ import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_PISSA = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
......
......@@ -27,10 +27,10 @@ from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
......@@ -41,6 +41,7 @@ TRAIN_ARGS = {
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
"report_to": "none",
}
......
# change if test fails
0.9.3.101
0.9.3.102
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment