"models/vscode:/vscode.git/clone" did not exist on "97226d97d4e65a49ba7eb5c75c98a2cb1c8633bd"
Commit 581d366d authored by chenych's avatar chenych
Browse files

Support GLM-4/GLM-4-0414/GLM-Z1

parent 428c5813
@@ -19,12 +19,12 @@ import pytest
 from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model

-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
 TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")

 INFER_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "template": "llama3",
     "infer_dtype": "float16",
 }
@@ -37,7 +37,7 @@ def fix_valuehead_cpu_loading():
 def test_base():
     model = load_infer_model(**INFER_ARGS)
-    ref_model = load_reference_model(TINY_LLAMA)
+    ref_model = load_reference_model(TINY_LLAMA3)
     compare_model(model, ref_model)
...
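These tests all funnel through compare_model from llamafactory.train.test_utils, which asserts that two models end up with matching (or deliberately different) weights. A minimal sketch of such a helper, assuming it walks both state dicts with torch.allclose (the real implementation may differ; diff_keys mirrors how the LoRA tests below call it):

import torch


def compare_model(model_a, model_b, diff_keys=None):
    # Hypothetical sketch; the real helper lives in llamafactory.train.test_utils.
    diff_keys = diff_keys or []
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    assert state_a.keys() == state_b.keys()
    for name in state_a:
        equal = torch.allclose(state_a[name], state_b[name], rtol=1e-4, atol=1e-4)
        if any(key in name for key in diff_keys):
            assert not equal, f"{name} was expected to be retrained"  # trained params must move
        else:
            assert equal, f"{name} was expected to match"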
@@ -19,10 +19,10 @@ import torch
 from llamafactory.train.test_utils import load_infer_model, load_train_model

-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

 TRAIN_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "stage": "sft",
     "do_train": True,
     "finetuning_type": "freeze",
@@ -36,7 +36,7 @@ TRAIN_ARGS = {
 }

 INFER_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "finetuning_type": "freeze",
     "template": "llama3",
     "infer_dtype": "float16",
...
@@ -19,10 +19,10 @@ import torch
 from llamafactory.train.test_utils import load_infer_model, load_train_model

-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

 TRAIN_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "stage": "sft",
     "do_train": True,
     "finetuning_type": "full",
@@ -36,7 +36,7 @@ TRAIN_ARGS = {
 }

 INFER_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "finetuning_type": "full",
     "template": "llama3",
     "infer_dtype": "float16",
...
@@ -27,14 +27,14 @@ from llamafactory.train.test_utils import (
 )

-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
 TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
 TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")

 TRAIN_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "stage": "sft",
     "do_train": True,
     "finetuning_type": "lora",
@@ -48,7 +48,7 @@ TRAIN_ARGS = {
 }

 INFER_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "adapter_name_or_path": TINY_LLAMA_ADAPTER,
     "finetuning_type": "lora",
     "template": "llama3",
@@ -81,13 +81,13 @@ def test_lora_train_extra_modules():
 def test_lora_train_old_adapters():
     model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
-    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
+    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
     compare_model(model, ref_model)


 def test_lora_train_new_adapters():
     model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
-    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
+    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
     compare_model(
         model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
     )
@@ -105,5 +105,5 @@ def test_lora_train_valuehead():
 def test_lora_inference():
     model = load_infer_model(**INFER_ARGS)
-    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
+    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
     compare_model(model, ref_model)
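test_lora_inference compares against a PEFT reference whose adapter has been folded into the base weights via merge_and_unload, the standard peft API for producing an ordinary (non-PEFT) model. A sketch of that reference path using plain transformers + peft, with the repo IDs taken from the constants above (assumed equivalent to what load_reference_model(..., use_lora=True) does internally):

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the tiny base model, attach the LoRA adapter, then merge the adapter
# deltas into the base weights so the result has no PEFT wrappers left.
base = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3")
lora_model = PeftModel.from_pretrained(base, "llamafactory/tiny-random-Llama-3-lora")
merged = lora_model.merge_and_unload()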
@@ -19,12 +19,12 @@ import pytest
 from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model

-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
 TINY_LLAMA_PISSA = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")

 TRAIN_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "stage": "sft",
     "do_train": True,
     "finetuning_type": "lora",
...
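PiSSA differs from vanilla LoRA only in initialization: the adapter starts from the principal singular components of each target weight matrix rather than from random noise. peft exposes this through init_lora_weights; a small illustrative sketch (the r and target_modules values here are assumptions, not taken from the test config):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3")
# init_lora_weights="pissa" seeds the adapter from an SVD of the frozen base weights.
config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"], init_lora_weights="pissa")
model = get_peft_model(base, config)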
@@ -27,10 +27,10 @@ from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer

 DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")

-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

 TRAIN_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "stage": "sft",
     "do_train": True,
     "finetuning_type": "lora",
@@ -41,6 +41,7 @@ TRAIN_ARGS = {
     "overwrite_output_dir": True,
     "per_device_train_batch_size": 1,
     "max_steps": 1,
+    "report_to": "none",
 }
...
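The one functional change in this file is the added "report_to": "none", which flows through to transformers.TrainingArguments and keeps the one-step test run from initializing logging integrations such as wandb or tensorboard. The standalone equivalent (output_dir here is a hypothetical placeholder):

from transformers import TrainingArguments

# report_to="none" disables every logging integration; useful in CI
# where no experiment tracker is configured.
args = TrainingArguments(output_dir="/tmp/sft_test", max_steps=1, report_to="none")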
 # change if test fails
-0.9.3.101
+0.9.3.102