OpenDAS / LLaMA-Factory

Commit ca625f43, authored Mar 30, 2026 by shihm
Commit message: uodata
Parent: 7164651d

Changes: 327 files in the full commit; this page shows 7 changed files, with 628 additions and 0 deletions (+628, -0):

- tests_v1/plugins/model_plugins/test_kernel_plugin.py (+72, -0)
- tests_v1/plugins/model_plugins/test_peft.py (+156, -0)
- tests_v1/plugins/model_plugins/test_quantization_plugin.py (+51, -0)
- tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py (+104, -0)
- tests_v1/sampler/test_cli_sampler.py (+44, -0)
- tests_v1/trainers/test_fsdp2_sft_trainer.py (+89, -0)
- tests_v1/utils/test_batching_queue.py (+112, -0)
tests_v1/plugins/model_plugins/test_kernel_plugin.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import MagicMock, patch

import pytest
from transformers import AutoModelForCausalLM

from llamafactory.v1.accelerator.helper import get_current_accelerator


@pytest.fixture(autouse=True)
def clear_accelerator_cache():
    get_current_accelerator.cache_clear()


def reload_kernels():
    """Helper to reload kernel modules to respect mocked accelerator."""
    # Unload kernel interface and registry
    keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
    for k in keys_to_remove:
        del sys.modules[k]


@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(mock_get_accelerator: MagicMock):
    mock_device = MagicMock()
    setattr(mock_device, "type", "npu")
    mock_get_accelerator.return_value = mock_device

    # Force reload of kernels with mocked accelerator
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward
    model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
    assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__


@patch("torch.accelerator.current_accelerator")
def test_apply_all_kernels(mock_get_accelerator: MagicMock):
    get_current_accelerator.cache_clear()
    mock_device = MagicMock()
    setattr(mock_device, "type", "npu")
    mock_get_accelerator.return_value = mock_device

    # Force reload of kernels with mocked accelerator
    reload_kernels()
    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
    original_swiglu_forward = model.model.layers[0].mlp.forward
    model = apply_default_kernels(model=model, include_kernels=True)
    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
    assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
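
Side note: the autouse fixture and reload_kernels helper exist because both the accelerator lookup and the kernel registry capture the detected device once and reuse it. A minimal, self-contained sketch of that caching pitfall, assuming get_current_accelerator is an lru_cache-wrapped lookup (the real helper may differ):

from functools import lru_cache
from unittest.mock import MagicMock, patch

import torch


@lru_cache
def cached_accelerator():  # hypothetical stand-in for get_current_accelerator
    return torch.accelerator.current_accelerator()


with patch("torch.accelerator.current_accelerator") as mock_acc:
    mock_device = MagicMock()
    setattr(mock_device, "type", "npu")
    mock_acc.return_value = mock_device
    cached_accelerator.cache_clear()  # without this, a device cached by an earlier call leaks in
    assert cached_accelerator().type == "npu"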
tests_v1/plugins/model_plugins/test_peft.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llamafactory.v1.plugins.model_plugins import peft as peft_module
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model


TINY_MODEL = "llamafactory/tiny-random-qwen3"


@pytest.fixture(scope="module")
def model_path():
    return TINY_MODEL


@pytest.fixture(scope="function")
def model(model_path):
    return AutoModelForCausalLM.from_pretrained(model_path)


@pytest.fixture(scope="function")
def tokenizer(model_path):
    return AutoTokenizer.from_pretrained(model_path)


@pytest.fixture(scope="function")
def adapter_path(tmp_path):
    # Create a dummy adapter
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    base_model = AutoModelForCausalLM.from_pretrained(TINY_MODEL)
    peft_model = get_peft_model(base_model, lora_config)
    save_path = tmp_path / "test_adapter"
    peft_model.save_pretrained(save_path)
    return str(save_path)


def test_find_all_linear_modules(model):
    """Verify linear modules are discoverable and include q_proj / v_proj for tiny-random-qwen3."""
    modules = peft_module._find_all_linear_modules(model)
    expected_subset = {"q_proj", "v_proj"}
    assert expected_subset.issubset(set(modules))


def test_get_lora_model(model):
    """Verify a PeftModel is returned and LoRA config takes effect."""
    config = {"name": "lora", "r": 8, "target_modules": "all", "lora_alpha": 16}
    model = peft_module.get_lora_model(model, config, is_train=True)
    assert isinstance(model, PeftModel)
    assert model.peft_config["default"].r == 8
    assert "q_proj" in model.peft_config["default"].target_modules


def test_get_freeze_model_layers(model):
    """Verify layer-wise freezing: only the last layer stays trainable."""
    # Freeze all but last layer
    config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "all"}
    # Ensure we start with something known
    model = peft_module.get_freeze_model(model, config, is_train=True)
    num_layers = model.config.num_hidden_layers
    assert num_layers > 0
    for name, param in model.named_parameters():
        if f"layers.{num_layers - 1}" in name:
            assert param.requires_grad, f"{name} should be trainable"
        elif "layers.0" in name and num_layers > 1:
            assert not param.requires_grad, f"{name} should be frozen"


def test_get_freeze_model_modules(model):
    """Verify module-wise freezing: only last-layer self_attn is trainable."""
    # Freeze specific modules (e.g. only self_attn)
    config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "self_attn"}
    model = peft_module.get_freeze_model(model, config, is_train=True)
    num_layers = model.config.num_hidden_layers
    for name, param in model.named_parameters():
        if f"layers.{num_layers - 1}" in name and "self_attn" in name:
            assert param.requires_grad, f"{name} should be trainable"
        else:
            assert not param.requires_grad, f"{name} should be frozen"


def test_load_adapter_single_for_inference(model, adapter_path):
    """Verify single adapter is merged+unloaded in inference mode."""
    # Test loading single adapter for inference (merge and unload)
    model_result = peft_module.load_adapter(model, adapter_path, is_train=False)
    assert not isinstance(model_result, PeftModel)


def test_load_adapter_resume_train(model, adapter_path):
    """Verify training mode returns a trainable PeftModel."""
    # Test loading for training
    model_result = peft_module.load_adapter(model, adapter_path, is_train=True)
    assert isinstance(model_result, PeftModel)


def test_load_adapter_train_multiple_disallowed(model, adapter_path):
    """Verify multiple adapters are rejected in training mode."""
    with pytest.raises(ValueError, match="only a single LoRA adapter"):
        peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=True)


def test_load_adapter_infer_multiple_merges(model, adapter_path):
    """Verify multiple adapters are merged in inference mode."""
    # Test merging multiple adapters
    model_result = peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=False)
    assert not isinstance(model_result, PeftModel)


def test_merge_and_export_model(tmp_path, adapter_path):
    """Verify merge_and_export_model produces export artifacts."""
    export_dir = tmp_path / "export"
    args_dict = {
        "model": TINY_MODEL,
        "peft_config": {
            "name": "lora",
            "adapter_name_or_path": adapter_path,
            "export_dir": str(export_dir),
            "export_size": 1,
            "infer_dtype": "float16",
        },
    }
    merge_and_export_model(args_dict)
    assert export_dir.exists()
    assert (export_dir / "config.json").exists()
    assert (export_dir / "model.safetensors").exists()
    assert (export_dir / "tokenizer_config.json").exists()
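
For context, the load_adapter assertions above mirror peft's merge-and-unload behavior: in inference mode the adapter weights are folded into the base model and the PeftModel wrapper is dropped. A minimal sketch against peft's public API (the plugin's internals may differ), using a hypothetical adapter path:

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
peft_model = PeftModel.from_pretrained(base, "/path/to/adapter")  # hypothetical path
merged = peft_model.merge_and_unload()  # fold LoRA deltas into the base weights
assert not isinstance(merged, PeftModel)  # back to a plain transformers model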
tests_v1/plugins/model_plugins/test_quantization_plugin.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_engine import ModelEngine


bitsandbytes = pytest.importorskip("bitsandbytes")


def check_quantization_status(model):
    quantized_info = {"bnb": []}
    for name, module in model.named_modules():
        # check BitsAndBytes quantization
        if isinstance(module, bitsandbytes.nn.modules.Linear8bitLt) or isinstance(
            module, bitsandbytes.nn.modules.Linear4bit
        ):
            quantized_info["bnb"].append(name)

    return quantized_info


@pytest.mark.runs_on(["cuda"])
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
def test_quantization_plugin(name, quantization_bit):
    model_args = ModelArguments(
        model="llamafactory/tiny-random-qwen3",
        quant_config={
            "name": name,
            "quantization_bit": quantization_bit,
        },
    )
    model_engine = ModelEngine(model_args=model_args)
    quantized_info = check_quantization_status(model_engine.model)
    print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")
    assert any(v for v in quantized_info.values()), "model is not quantized properly."
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests: FSDP2 meta-device loading vs normal loading consistency.
Validates that the FSDP2 meta loading path behaves correctly for tied weights
and non-persistent buffers by comparing it with the standard non-meta path.
"""
import
torch
from
transformers
import
AutoConfig
from
llamafactory.v1.accelerator.interface
import
DistributedInterface
from
llamafactory.v1.config.arg_parser
import
get_args
from
llamafactory.v1.core.model_engine
import
ModelEngine
from
llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2
import
FSDP2Engine
TINY_MODEL
=
"llamafactory/tiny-random-qwen3"
def
collect_non_persistent_buffers
(
model
):
"""Collect all non-persistent buffers from model."""
result
=
{}
for
mod_name
,
module
in
model
.
named_modules
():
for
buf_name
in
getattr
(
module
,
"_non_persistent_buffers_set"
,
set
()):
fqn
=
f
"
{
mod_name
}
.
{
buf_name
}
"
if
mod_name
else
buf_name
buf
=
getattr
(
module
,
buf_name
,
None
)
if
buf
is
not
None
:
result
[
fqn
]
=
buf
.
detach
().
cpu
().
clone
()
return
result
def
test_fsdp2_meta_loading_buffers_and_tied_weights
():
"""Verify non-persistent buffers and tied weights consistency after meta load."""
# 1. Initialize DistributedInterface for single process
DistributedInterface
()
# 2. Build FSDP2Engine config
engine
=
FSDP2Engine
(
{
"name"
:
"fsdp2"
,
"mixed_precision"
:
"bf16"
,
"reshard_after_forward"
:
True
,
"offload_params"
:
False
,
"pin_memory"
:
False
,
"dcp_path"
:
None
,
}
)
config
=
AutoConfig
.
from_pretrained
(
TINY_MODEL
)
# --- NORMAL PATH ---
normal_args
,
*
_
=
get_args
(
dict
(
model
=
TINY_MODEL
,
init_config
=
None
))
normal_engine
=
ModelEngine
(
model_args
=
normal_args
)
normal_model
=
normal_engine
.
model
.
to
(
torch
.
bfloat16
)
normal_model
=
engine
.
shard_model
(
normal_model
)
normal_non_persistent
=
collect_non_persistent_buffers
(
normal_model
)
del
normal_model
# --- META PATH ---
meta_args
,
*
_
=
get_args
(
dict
(
model
=
TINY_MODEL
,
init_config
=
{
"name"
:
"init_on_meta"
}))
meta_model_engine
=
ModelEngine
(
model_args
=
meta_args
)
meta_model
=
meta_model_engine
.
model
assert
meta_model
.
device
.
type
==
"meta"
,
"Model should be on meta device"
# Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers
meta_model
=
engine
.
shard_model
(
meta_model
)
meta_non_persistent
=
collect_non_persistent_buffers
(
meta_model
)
# 3. Tied weights (embed_tokens.weight and lm_head.weight)
tie_word_embeddings
=
getattr
(
config
,
"tie_word_embeddings"
,
False
)
if
tie_word_embeddings
:
assert
meta_model
.
lm_head
.
weight
is
meta_model
.
model
.
embed_tokens
.
weight
,
(
"Weights should be tied after loading"
)
del
meta_model
# 4. Non-persistent buffers (e.g., inv_freq)
normal_buf_keys
=
set
(
normal_non_persistent
.
keys
())
meta_buf_keys
=
set
(
meta_non_persistent
.
keys
())
assert
normal_buf_keys
==
meta_buf_keys
,
"Non-persistent buffer keys mismatch"
for
key
in
sorted
(
normal_buf_keys
&
meta_buf_keys
):
nb
=
normal_non_persistent
[
key
]
mb
=
meta_non_persistent
[
key
]
assert
nb
.
shape
==
mb
.
shape
,
f
"Buffer shape mismatch:
{
key
}
"
assert
torch
.
allclose
(
nb
.
float
(),
mb
.
float
(),
atol
=
1e-5
),
f
"Buffer value mismatch:
{
key
}
"
tests_v1/sampler/test_cli_sampler.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from llamafactory.v1.config import ModelArguments, SampleArguments
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.samplers.cli_sampler import SyncSampler


@pytest.mark.runs_on(["cuda", "npu"])
def test_sync_sampler():
    model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507", template="qwen3_nothink")
    sample_args = SampleArguments()
    model_engine = ModelEngine(model_args)
    sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
    messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}]
    response = ""
    for new_text in sampler.generate(messages):
        response += new_text

    print(response)
    assert model_engine.renderer.parse_message(response) == {
        "role": "assistant",
        "content": [{"type": "text", "value": "This is a test."}],
    }


if __name__ == "__main__":
    """
    python tests_v1/sampler/test_cli_sampler.py
    """
    test_sync_sampler()
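
Note that runs_on is a project-specific pytest marker rather than a built-in one; a sketch of how such a marker could be registered so pytest recognizes it (the repository's actual conftest may do this differently):

# conftest.py (hypothetical)
def pytest_configure(config):
    config.addinivalue_line("markers", "runs_on(devices): run the test only on the listed accelerators")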
tests_v1/trainers/test_fsdp2_sft_trainer.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path

import pytest


@pytest.mark.xfail(reason="CI machines may OOM when heavily loaded.")
@pytest.mark.runs_on(["cuda", "npu"])
def test_fsdp2_sft_trainer(tmp_path: Path):
    """Test FSDP2 SFT trainer by simulating `llamafactory-cli sft config.yaml` behavior."""
    config_yaml = """\
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto
quant_config: null
dist_config:
  name: fsdp2
  dcp_path: null
init_config:
  name: init_on_meta

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: {output_dir}
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 1

### sample
sample_backend: hf
max_new_tokens: 128
"""
    # Create output directory
    output_dir = tmp_path / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)

    config_file = tmp_path / "config.yaml"
    config_file.write_text(config_yaml.format(output_dir=str(output_dir)))

    # Set up environment variables
    env = os.environ.copy()
    env["USE_V1"] = "1"  # Use v1 launcher
    env["FORCE_TORCHRUN"] = "1"  # Force distributed training via torchrun

    # Run the CLI command via subprocess
    # This simulates: llamafactory-cli sft config.yaml
    result = subprocess.run(
        [sys.executable, "-m", "llamafactory.cli", "sft", str(config_file)],
        env=env,
        capture_output=True,
        cwd=str(Path(__file__).parent.parent.parent),  # LLaMA-Factory root
    )

    # Decode output with error handling (progress bars may contain non-UTF-8 bytes)
    stderr = result.stderr.decode("utf-8", errors="replace")

    # Check the result
    assert result.returncode == 0, (
        f"Training failed with return code {result.returncode}\nSTDERR: {stderr}"
    )

    # Verify output files exist (optional - adjust based on what run_sft produces)
    # assert (output_dir / "some_expected_file").exists()
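
For manual debugging, the subprocess invocation above corresponds roughly to running the following from the repository root, with the same environment variables the test sets (config path is a placeholder):

USE_V1=1 FORCE_TORCHRUN=1 python -m llamafactory.cli sft /path/to/config.yaml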
tests_v1/utils/test_batching_queue.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue


def create_sample(length: int):
    """Helper to create a mock sample with a specific token length."""
    return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)}


class TestDynamicBatchSizeBuffer:
    def test_append_and_token_count(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(20))
        assert len(buffer) == 2
        assert buffer.total_token_count == 30

    def test_get_samples_within_budget(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(10))
        buffer.append(create_sample(50))  # This one is large
        # Request 25 tokens. Should get the first two (20 tokens total)
        samples = buffer.get_samples(max_tokens_per_iteration=25)
        assert len(samples) == 2

    def test_force_return_first_sample(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(100))
        # Even though budget is 50, force=True (default) should return the 100-token sample
        samples = buffer.get_samples(max_tokens_per_iteration=50, force=True)
        assert len(samples) == 1
        assert len(samples[0]["input_ids"]) == 100

    def test_flush_removes_used_samples(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(20))
        # Take the first sample
        buffer.get_samples(max_tokens_per_iteration=15)
        buffer.flush()
        assert len(buffer) == 1
        assert buffer.total_token_count == 20
        # The remaining sample should now be at the start
        remaining = buffer.get_samples(max_tokens_per_iteration=50)
        assert len(remaining[0]["input_ids"]) == 20


class TestTextBatchingQueue:
    def test_is_full_filled(self):
        queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2)
        queue.put_item(create_sample(10))
        assert not queue.is_full_filled()  # Only 1 sample, buffer_size=2
        queue.put_item(create_sample(10))
        assert not queue.is_full_filled()  # 2 samples, but only 20 tokens (min 100)
        queue.put_item(create_sample(90))
        assert queue.is_full_filled()  # Meets both conditions

    def test_warmup_logic(self):
        # token_micro_bsz=1000, starts at 200, reaches 1000 at step 10
        queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200)
        # Step 0: should be init value
        assert queue.get_cur_token_micro_bsz() == 200
        # Step 5: halfway through warmup (200 + (800 * 5/10)) = 600
        queue._step = 5
        assert queue.get_cur_token_micro_bsz() == 600
        # Step 11: past warmup
        queue._step = 11
        assert queue.get_cur_token_micro_bsz() == 1000

    def test_get_micro_batch_integration(self):
        queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1)
        queue.put_item(create_sample(20))
        queue.put_item(create_sample(20))
        queue.put_item(create_sample(20))
        # At step 0 (warmup not triggered as bsz_warmup_steps is -1 default),
        # it should take samples up to 50 tokens.
        batch = queue.get_micro_batch(step=0)
        assert len(batch) == 2
        assert queue.empty() is False
        batch_2 = queue.get_micro_batch(step=1)
        assert len(batch_2) == 1
        assert queue.empty() is True
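
The warmup expectations above follow a linear ramp from bsz_warmup_init_mbtoken to token_micro_bsz; a minimal sketch of that arithmetic, matching the comments in the test (the class's actual code may round differently):

def cur_token_micro_bsz(step: int, target: int = 1000, warmup_steps: int = 10, init: int = 200) -> int:
    """Linearly interpolate the per-step token budget during warmup."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return target
    return init + (target - init) * step // warmup_steps


assert cur_token_micro_bsz(0) == 200    # step 0: init value
assert cur_token_micro_bsz(5) == 600    # 200 + 800 * 5 // 10
assert cur_token_micro_bsz(11) == 1000  # past warmup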