test_fsdp2_sft_trainer.py 2.54 KB
Newer Older
shihm's avatar
uodata  
shihm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path

import pytest


@pytest.mark.xfail(reason="CI machines may OOM when heavily loaded.")
@pytest.mark.runs_on(["cuda", "npu"])
def test_fsdp2_sft_trainer(tmp_path: Path):
    """Run one FSDP2 SFT step through the CLI, as `llamafactory-cli sft config.yaml` would.

    A YAML config is written under ``tmp_path`` and ``python -m llamafactory.cli sft <config>``
    is launched in a subprocess with ``USE_V1=1`` and ``FORCE_TORCHRUN=1``. The test passes
    when the subprocess exits with return code 0; otherwise both captured output streams are
    included in the assertion message.

    Args:
        tmp_path: Per-test temporary directory provided by pytest; it holds the generated
            config file and the training output directory.
    """
    config_yaml = """\
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm

template: qwen3_nothink

kernel_config:
    name: auto
    include_kernels: auto

quant_config: null

dist_config:
    name: fsdp2
    dcp_path: null

init_config:
    name: init_on_meta

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: {output_dir}
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 1

### sample
sample_backend: hf
max_new_tokens: 128
"""
    # Create output directory
    output_dir = tmp_path / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)
    config_file = tmp_path / "config.yaml"
    # Write explicitly as UTF-8 so the config does not depend on the platform's locale encoding.
    config_file.write_text(config_yaml.format(output_dir=str(output_dir)), encoding="utf-8")

    # Set up environment variables
    env = os.environ.copy()
    env["USE_V1"] = "1"  # Use v1 launcher
    env["FORCE_TORCHRUN"] = "1"  # Force distributed training via torchrun

    # Run the CLI command via subprocess
    # This simulates: llamafactory-cli sft config.yaml
    result = subprocess.run(
        [sys.executable, "-m", "llamafactory.cli", "sft", str(config_file)],
        env=env,
        capture_output=True,
        # Three levels above this file; the relative `train_dataset` path in the config
        # is presumably resolved against this directory (LLaMA-Factory root).
        cwd=str(Path(__file__).parent.parent.parent),
    )

    # Decode output with error handling (progress bars may contain non-UTF-8 bytes).
    # Both streams are captured, so report both: dropping stdout hides half of the diagnostics.
    stdout = result.stdout.decode("utf-8", errors="replace")
    stderr = result.stderr.decode("utf-8", errors="replace")

    # Check the result
    assert result.returncode == 0, (
        f"Training failed with return code {result.returncode}\nSTDOUT: {stdout}\nSTDERR: {stderr}"
    )

    # TODO: assert on expected files under `output_dir` once it is settled what the run produces.