Unverified Commit 8f783c19 authored by eigen's avatar eigen Committed by GitHub
Browse files

[Model Support] unsloth/Phi-4-mini bnb model (#4982)


Co-authored-by: default avataryhyang201 <yhyang201@gmail.com>
Co-authored-by: default avatarLiangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: default avatarChayenne <zhaochen20@outlook.com>
Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 90faf901
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
import itertools
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset): ...@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
def adjust_bitsandbytes_4bit_shard( def adjust_bitsandbytes_4bit_shard(
param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
) -> Tuple[int, int]: ) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total, _ = qkv_offsets["total"] total, _ = shard_offsets["total"]
orig_offset, orig_size = qkv_offsets[loaded_shard_id] orig_offset, orig_size = shard_offsets[loaded_shard_id]
quantized_total = param.data.shape[0] quantized_total = param.data.shape[0]
quantized_offset = orig_offset * quantized_total // total quantized_offset = orig_offset * quantized_total // total
...@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offsets.append((i, current_shard_offset, output_size)) shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
for shard_id, shard_offset, shard_size in shard_offsets: for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization. # Special case for Quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
...@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param, shard_size, shard_offset param, shard_size, shard_offset
) )
if use_bitsandbytes_4bit:
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
str(i): (index[i], size)
for i, size in enumerate(self.output_sizes)
}
orig_offsets["total"] = (self.output_size, 0)
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_offsets, str(shard_id)
)
loaded_weight_shard = loaded_weight.narrow( loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size output_dim, shard_offset, shard_size
) )
......
...@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module): ...@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
column_parallel_weights_modules = [".down_proj.", ".o_proj."] column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = { bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index # shard_name, weight_name, index
"q_proj": ("qkv_proj", 0), ".q_proj": (".qkv_proj", 0),
"k_proj": ("qkv_proj", 1), ".k_proj": (".qkv_proj", 1),
"v_proj": ("qkv_proj", 2), ".v_proj": (".qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0), ".gate_proj": (".gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1), ".up_proj": (".gate_up_proj", 1),
} }
def __init__( def __init__(
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestUnslothPhi4(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "unsloth/phi-4"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.78)
class TestUnslothPhi4Bnb4bit(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "unsloth/phi-4-bnb-4bit"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--load-format",
"bitsandbytes",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.75)
class TestUnslothPhi4UnslothBnb4bit(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "unsloth/phi-4-unsloth-bnb-4bit"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--load-format",
"bitsandbytes",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.75)
class TestUnslothPhi4MiniInstruct(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "unsloth/Phi-4-mini-instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.65)
class TestUnslothPhi4MiniBnb4bit(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "unsloth/Phi-4-mini-instruct-bnb-4bit"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--load-format",
"bitsandbytes",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.6)
class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--load-format",
"bitsandbytes",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.6)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment