Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
LLaMA-Factory
Commits
581d366d
Commit
581d366d
authored
Apr 15, 2025
by
chenych
Browse files
Support GLM-4/GLM-4-0414/GLM-Z1
parent
428c5813
Changes
107
Hide whitespace changes
Inline
Side-by-side
Showing 7 changed files with 21 additions and 20 deletions
+21
-20
tests/model/test_base.py
tests/model/test_base.py
+3
-3
tests/model/test_freeze.py
tests/model/test_freeze.py
+3
-3
tests/model/test_full.py
tests/model/test_full.py
+3
-3
tests/model/test_lora.py
tests/model/test_lora.py
+6
-6
tests/model/test_pissa.py
tests/model/test_pissa.py
+2
-2
tests/train/test_sft_trainer.py
tests/train/test_sft_trainer.py
+3
-2
tests/version.txt
tests/version.txt
+1
-1
No files found.
tests/model/test_base.py
View file @
581d366d
...
...
@@ -19,12 +19,12 @@ import pytest
from
llamafactory.train.test_utils
import
compare_model
,
load_infer_model
,
load_reference_model
,
patch_valuehead_model
TINY_LLAMA
=
os
.
getenv
(
"TINY_LLAMA"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA_VALUEHEAD
=
os
.
getenv
(
"TINY_LLAMA_VALUEHEAD"
,
"llamafactory/tiny-random-Llama-3-valuehead"
)
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"template"
:
"llama3"
,
"infer_dtype"
:
"float16"
,
}
...
...
@@ -37,7 +37,7 @@ def fix_valuehead_cpu_loading():
def
test_base
():
model
=
load_infer_model
(
**
INFER_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA
)
ref_model
=
load_reference_model
(
TINY_LLAMA3
)
compare_model
(
model
,
ref_model
)
...
...
tests/model/test_freeze.py
View file @
581d366d
...
...
@@ -19,10 +19,10 @@ import torch
from
llamafactory.train.test_utils
import
load_infer_model
,
load_train_model
TINY_LLAMA
=
os
.
getenv
(
"TINY_LLAMA"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"freeze"
,
...
...
@@ -36,7 +36,7 @@ TRAIN_ARGS = {
}
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"finetuning_type"
:
"freeze"
,
"template"
:
"llama3"
,
"infer_dtype"
:
"float16"
,
...
...
tests/model/test_full.py
View file @
581d366d
...
...
@@ -19,10 +19,10 @@ import torch
from
llamafactory.train.test_utils
import
load_infer_model
,
load_train_model
TINY_LLAMA
=
os
.
getenv
(
"TINY_LLAMA"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"full"
,
...
...
@@ -36,7 +36,7 @@ TRAIN_ARGS = {
}
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"finetuning_type"
:
"full"
,
"template"
:
"llama3"
,
"infer_dtype"
:
"float16"
,
...
...
tests/model/test_lora.py
View file @
581d366d
...
...
@@ -27,14 +27,14 @@ from llamafactory.train.test_utils import (
)
TINY_LLAMA
=
os
.
getenv
(
"TINY_LLAMA"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA_ADAPTER
=
os
.
getenv
(
"TINY_LLAMA_ADAPTER"
,
"llamafactory/tiny-random-Llama-3-lora"
)
TINY_LLAMA_VALUEHEAD
=
os
.
getenv
(
"TINY_LLAMA_VALUEHEAD"
,
"llamafactory/tiny-random-Llama-3-valuehead"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"lora"
,
...
...
@@ -48,7 +48,7 @@ TRAIN_ARGS = {
}
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"adapter_name_or_path"
:
TINY_LLAMA_ADAPTER
,
"finetuning_type"
:
"lora"
,
"template"
:
"llama3"
,
...
...
@@ -81,13 +81,13 @@ def test_lora_train_extra_modules():
def
test_lora_train_old_adapters
():
model
=
load_train_model
(
adapter_name_or_path
=
TINY_LLAMA_ADAPTER
,
create_new_adapter
=
False
,
**
TRAIN_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA
,
TINY_LLAMA_ADAPTER
,
use_lora
=
True
,
is_trainable
=
True
)
ref_model
=
load_reference_model
(
TINY_LLAMA3
,
TINY_LLAMA_ADAPTER
,
use_lora
=
True
,
is_trainable
=
True
)
compare_model
(
model
,
ref_model
)
def
test_lora_train_new_adapters
():
model
=
load_train_model
(
adapter_name_or_path
=
TINY_LLAMA_ADAPTER
,
create_new_adapter
=
True
,
**
TRAIN_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA
,
TINY_LLAMA_ADAPTER
,
use_lora
=
True
,
is_trainable
=
True
)
ref_model
=
load_reference_model
(
TINY_LLAMA3
,
TINY_LLAMA_ADAPTER
,
use_lora
=
True
,
is_trainable
=
True
)
compare_model
(
model
,
ref_model
,
diff_keys
=
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
"o_proj"
,
"up_proj"
,
"gate_proj"
,
"down_proj"
]
)
...
...
@@ -105,5 +105,5 @@ def test_lora_train_valuehead():
def
test_lora_inference
():
model
=
load_infer_model
(
**
INFER_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA
,
TINY_LLAMA_ADAPTER
,
use_lora
=
True
).
merge_and_unload
()
ref_model
=
load_reference_model
(
TINY_LLAMA3
,
TINY_LLAMA_ADAPTER
,
use_lora
=
True
).
merge_and_unload
()
compare_model
(
model
,
ref_model
)
tests/model/test_pissa.py
View file @
581d366d
...
...
@@ -19,12 +19,12 @@ import pytest
from
llamafactory.train.test_utils
import
compare_model
,
load_infer_model
,
load_reference_model
,
load_train_model
TINY_LLAMA
=
os
.
getenv
(
"TINY_LLAMA"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA_PISSA
=
os
.
getenv
(
"TINY_LLAMA_ADAPTER"
,
"llamafactory/tiny-random-Llama-3-pissa"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"lora"
,
...
...
tests/train/test_sft_trainer.py
View file @
581d366d
...
...
@@ -27,10 +27,10 @@ from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer
DEMO_DATA
=
os
.
getenv
(
"DEMO_DATA"
,
"llamafactory/demo_data"
)
TINY_LLAMA
=
os
.
getenv
(
"TINY_LLAMA"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA
,
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"lora"
,
...
...
@@ -41,6 +41,7 @@ TRAIN_ARGS = {
"overwrite_output_dir"
:
True
,
"per_device_train_batch_size"
:
1
,
"max_steps"
:
1
,
"report_to"
:
"none"
,
}
...
...
tests/version.txt
View file @
581d366d
# change if test fails
0.9.3.101
0.9.3.102
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment