sglang / Commits / 2d96da81

Unverified commit 2d96da81, authored Jul 19, 2024 by Ying Sheng, committed by GitHub on Jul 19, 2024.

refactor model loader [unreachable code]: initial refactor (#655)

Parent: c126a6cc

Showing 5 changed files with 2116 additions and 0 deletions (+2116, -0).
python/sglang/srt/layers/linear.py                 +869  -0
python/sglang/srt/layers/quantization/__init__.py   +49  -0
python/sglang/srt/layers/quantization/fp8.py       +662  -0
python/sglang/srt/model_loader/model_loader.py     +276  -0
python/sglang/srt/model_loader/utils.py            +260  -0
python/sglang/srt/layers/linear.py (new file, mode 100644)
This diff is collapsed.
python/sglang/srt/layers/quantization/__init__.py (new file, mode 100644)
# temporarily adapted from vLLM
# FIXME: in progress of refactoring the model loader

from typing import Dict, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsConfig,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

from sglang.srt.layers.quantization.fp8 import Fp8Config

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "fp8": Fp8Config,
    # The order of gptq methods is important for config.py iteration over
    # override_quantization_method(..)
    "marlin": MarlinConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]
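Aside (not part of the file): a minimal usage sketch of the registry defined above, assuming sglang and its vLLM dependency are installed; the snippet is illustrative and not part of the commit.

# Illustrative only: look up a quantization config class by name.
from sglang.srt.layers.quantization import (
    QUANTIZATION_METHODS,
    get_quantization_config,
)

print(sorted(QUANTIZATION_METHODS))       # ['aqlm', 'awq', 'bitsandbytes', ...]
fp8_cls = get_quantization_config("fp8")  # -> Fp8Config
print(fp8_cls.__name__)

try:
    get_quantization_config("int3")       # not a registered method
except ValueError as err:
    print(err)                            # "Invalid quantization method: int3"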
python/sglang/srt/layers/quantization/fp8.py (new file, mode 100644)
This diff is collapsed.
python/sglang/srt/model_loader/model_loader.py (new file, mode 100644)
# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py
# FIXME: in progress of refactoring the model loader

import glob
import os
import re
from typing import Any, Dict, Generator, List, Optional, Tuple, Type

import torch
from torch import nn
from tqdm import tqdm
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    LoadConfig,
    LoadFormat,
    LoRAConfig,
    ModelConfig,
    MultiModalConfig,
    ParallelConfig,
    SchedulerConfig,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.utils import (
    get_model_architecture,
    set_default_torch_dtype,
)
from vllm.platforms import current_platform

from sglang.srt.model_loader.utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    filter_duplicate_safetensors_files,
    get_quant_config,
    safetensors_weights_iterator,
)


def _get_quantization_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> Optional[QuantizationConfig]:
    """Get the quantization config."""
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config, load_config)
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}."
            )
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}"
            )
        return quant_config
    return None
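Aside (not part of the file): `current_platform.get_device_capability()` returns a (major, minor) tuple, and the check above folds it into a single integer before comparing it with the method's `get_min_capability()`. A tiny, GPU-free illustration of that arithmetic:

# Purely illustrative fold of a (major, minor) compute capability.
def fold_capability(capability):
    return capability[0] * 10 + capability[1]

assert fold_capability((8, 0)) == 80  # e.g. an A100
assert fold_capability((7, 5)) == 75  # e.g. a T4
# A quantization method whose get_min_capability() is 89 would reject both.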


def _get_model_initialization_kwargs(
    model_class: Type[nn.Module],
    lora_config: Optional[LoRAConfig],
    multimodal_config: Optional[MultiModalConfig],
) -> Dict[str, Any]:
    """Get extra kwargs for model initialization."""
    extra_kwargs: Dict[str, Any] = {}
    assert lora_config is None
    assert multimodal_config is None
    return extra_kwargs


def _initialize_model(
    model_config: ModelConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    multimodal_config: Optional[MultiModalConfig],
    cache_config: CacheConfig,
) -> nn.Module:
    """Initialize a model with the given configurations."""
    model_class = get_model_architecture(model_config)[0]
    quant_config = _get_quantization_config(model_config, load_config)
    return model_class(
        config=model_config.hf_config,
        cache_config=cache_config,
        quant_config=quant_config,
        **_get_model_initialization_kwargs(
            model_class, lora_config, multimodal_config
        ),
    )


class ModelLoader:
    """Model loader that can load different file types from disk."""

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    def _prepare_weights(
        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
    ) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir, revision
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt
        )
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                model_name_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
            )
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)
        return weights_iterator
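Aside (not part of the file): `_prepare_weights` and `_get_weights_iterator` refer to `filter_files_not_needed_for_inference`, `np_cache_weights_iterator`, and `pt_weights_iterator`, none of which are imported in this file; only the safetensors path is wired through `sglang.srt.model_loader.utils`. If the `.bin`/`.pt` and np-cache branches were exercised, the missing names would most plausibly come from vLLM's weight utilities, roughly as in the unverified sketch below (the module path may differ across vLLM versions).

# Hypothetical fix, not part of this commit: import the remaining weight
# iterators from vLLM, which this loader was adapted from.
from vllm.model_executor.model_loader.weight_utils import (
    filter_files_not_needed_for_inference,
    np_cache_weights_iterator,
    pt_weights_iterator,
)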

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                    lora_config,
                    multimodal_config,
                    cache_config,
                )
            weights = self._get_weights_iterator(
                model_config.model,
                model_config.revision,
                fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            )

            modules = {}
            for name, module in model.named_modules():
                modules[name] = module

            def apply_quant_method(module):
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # print("before apply quant", module.weight, module.weight.dtype)
                    quant_method.process_weights_after_loading(module)
                    # print("after apply quant", module.weight, module.weight.dtype)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()

            if torch.cuda.current_device() == 0:
                weights = tqdm(
                    weights, total=model.get_num_params() * 1.5, desc="load model"
                )

            num_shard = {}
            num_loaded = {}
            for name, loaded_weight in weights:
                model.load_weights(None, name, loaded_weight)
                module_name, shard_num = model.get_module_name(name)
                num_shard[module_name] = shard_num
                if module_name not in num_loaded:
                    num_loaded[module_name] = 1
                else:
                    num_loaded[module_name] += 1
                if num_loaded[module_name] == num_shard[module_name]:
                    apply_quant_method(modules[module_name])

        return model.eval()
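Aside (not part of the file): the bookkeeping at the end of `load_model` counts, per module, how many weight shards have arrived (`num_loaded`) against how many are expected (`num_shard`, as reported by `model.get_module_name`), and runs quantization post-processing only when the last shard of a module has been loaded. A self-contained toy illustration of that counting pattern, with made-up module names and shard counts:

# Toy illustration; the module names and shard counts are invented.
expected_shards = {"qkv_proj": 3, "gate_up_proj": 2}
arrivals = ["qkv_proj", "gate_up_proj", "qkv_proj", "gate_up_proj", "qkv_proj"]

num_shard, num_loaded = {}, {}
for module_name in arrivals:
    num_shard[module_name] = expected_shards[module_name]
    num_loaded[module_name] = num_loaded.get(module_name, 0) + 1
    if num_loaded[module_name] == num_shard[module_name]:
        # In the real loader this is where apply_quant_method(...) runs.
        print(f"all shards of {module_name} loaded, post-processing now")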


def get_model(
    *,
    model_config: ModelConfig,
    load_config: LoadConfig,
    device_config: DeviceConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    lora_config: Optional[LoRAConfig],
    multimodal_config: Optional[MultiModalConfig],
    cache_config: CacheConfig,
) -> nn.Module:
    loader = ModelLoader(load_config)
    return loader.load_model(
        model_config=model_config,
        device_config=device_config,
        lora_config=lora_config,
        multimodal_config=multimodal_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        cache_config=cache_config,
    )
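Aside (not part of the file): a hedged sketch of how this entry point is meant to be called. The config objects (`model_config`, `load_config`, and so on) are vLLM dataclasses whose construction is version-dependent and assumed to happen elsewhere, so this shows only the call shape rather than a runnable recipe.

# Call-shape sketch; assumes the vLLM config objects already exist.
from sglang.srt.model_loader.model_loader import get_model

model = get_model(
    model_config=model_config,
    load_config=load_config,
    device_config=device_config,
    parallel_config=parallel_config,
    scheduler_config=scheduler_config,
    lora_config=None,        # this loader asserts lora_config is None
    multimodal_config=None,  # and multimodal_config is None
    cache_config=cache_config,
)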
python/sglang/srt/model_loader/utils.py (new file, mode 100644)
# temporarily adapted from vLLM
# FIXME: in progress of refactoring the model loader
"""Utilities for selecting and loading models."""

import contextlib
import fnmatch
import hashlib
import json
import logging
import os
import tempfile
from typing import Any, Generator, Iterable, List, Optional, Tuple, Type

import filelock
import huggingface_hub.constants
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from torch import nn
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from sglang.srt.layers.quantization import get_quantization_config

logger = logging.getLogger("srt.model_loader")

temp_dir = tempfile.gettempdir()


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)
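Aside (not part of the file): a small example of the context manager above, assuming torch and sglang are installed. Note that the restore is not wrapped in try/finally, so an exception raised inside the with-block would leave the temporary default dtype in place.

import torch

from sglang.srt.model_loader.utils import set_default_torch_dtype

print(torch.get_default_dtype())             # typically torch.float32
with set_default_torch_dtype(torch.float16):
    x = torch.empty(2, 2)                    # allocated as float16
    print(x.dtype)
print(torch.get_default_dtype())             # previous default restored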


def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    architectures = getattr(model_config.hf_config, "architectures", [])
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    if (
        model_config.quantization is not None
        and model_config.quantization != "fp8"
        and "MixtralForCausalLM" in architectures
    ):
        architectures = ["QuantMixtralForCausalLM"]

    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return (model_cls, arch)
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}"
    )


class DisabledTqdm(tqdm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    lock_dir = cache_dir or temp_dir
    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
    return lock


def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (List[str]): The allowed patterns for the
            weight files. Files matched by any of the patterns will be
            downloaded.
        revision (Optional[str]): The revision of the model.

    Returns:
        str: The path to the downloaded model weights.
    """
    if not huggingface_hub.constants.HF_HUB_OFFLINE:
        # Before we download we look at that is available:
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # depending on what is available we download different things
        for pattern in allow_patterns:
            matching = fnmatch.filter(file_list, pattern)
            if len(matching) > 0:
                allow_patterns = [pattern]
                break

    logger.info("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        hf_folder = snapshot_download(
            model_name_or_path,
            allow_patterns=allow_patterns,
            cache_dir=cache_dir,
            tqdm_class=DisabledTqdm,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
    return hf_folder


def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        try:
            # Download the safetensors index file.
            hf_hub_download(
                repo_id=model_name_or_path,
                filename=SAFE_WEIGHTS_INDEX_NAME,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            )
        # If file not found on remote or locally, we should not fail since
        # only some models will have SAFE_WEIGHTS_INDEX_NAME.
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(
    hf_weights_files: List[str], hf_folder: str
) -> List[str]:
    # model.safetensors.index.json is a mapping from keys in the
    # torch state_dict to safetensors file holding that weight.
    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
    if not os.path.isfile(index_file_name):
        return hf_weights_files

    # Iterate through the weight_map (weight_name: safetensors files)
    # to identify weights that we should use.
    with open(index_file_name) as index_file:
        weight_map = json.load(index_file)["weight_map"]
    weight_files_in_index = set()
    for weight_name in weight_map:
        weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
    # Filter out any fields that are not found in the index file.
    hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
    return hf_weights_files
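Aside (not part of the file): `model.safetensors.index.json` maps each state-dict key to the shard file that holds it, and only shards named in that map survive the filter, which is how a stray consolidated checkpoint gets dropped. A self-contained sketch using a fabricated two-shard index (the file and weight names are made up):

import json
import os
import tempfile

hf_folder = tempfile.mkdtemp()
index = {
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }
}
with open(os.path.join(hf_folder, "model.safetensors.index.json"), "w") as f:
    json.dump(index, f)

present = [
    os.path.join(hf_folder, "model-00001-of-00002.safetensors"),
    os.path.join(hf_folder, "model-00002-of-00002.safetensors"),
    os.path.join(hf_folder, "consolidated.safetensors"),  # not in the index
]

# Same filtering logic as filter_duplicate_safetensors_files above.
with open(os.path.join(hf_folder, "model.safetensors.index.json")) as f:
    weight_map = json.load(f)["weight_map"]
wanted = {os.path.join(hf_folder, name) for name in weight_map.values()}
print([p for p in present if p in wanted])  # only the two sharded files remain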


def safetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    for st_file in hf_weights_files:
        with safe_open(st_file, framework="pt") as f:
            for name in f.keys():  # noqa: SIM118
                param = f.get_tensor(name)
                yield name, param


def get_quant_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig:
    quant_cls = get_quantization_config(model_config.quantization)
    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
    if hf_quant_config is None:
        # compressed-tensors uses a compressions_config
        hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)
    # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
    if model_config.quantization == "bitsandbytes":
        if (
            not load_config.model_loader_extra_config
            or "qlora_adapter_name_or_path"
            not in load_config.model_loader_extra_config
        ):
            return quant_cls.from_config({"adapter_name_or_path": ""})
        model_name_or_path = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"
        ]
    else:
        model_name_or_path = model_config.model
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )
    else:
        hf_folder = model_name_or_path

    possible_config_filenames = quant_cls.get_config_filenames()

    # If the quantization config is not found, use the default config.
    if not possible_config_filenames:
        return quant_cls()

    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_config_files = [
        f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
    ]
    if len(quant_config_files) == 0:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}"
        )
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}"
        )

    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
        config = json.load(f)

        if model_config.quantization == "bitsandbytes":
            config["adapter_name_or_path"] = model_name_or_path

    return quant_cls.from_config(config)
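Aside (not part of the file): two unresolved names are worth flagging. `get_quant_config` calls `glob.glob(...)` but this module never imports `glob`, and `get_model_architecture` uses a `ModelRegistry` that is not defined or imported here (in vLLM, which this file was adapted from, it lives in `vllm.model_executor.models`). A hedged sketch of the imports those paths would need; neither line is part of the commit:

# Hypothetical additions, not part of this commit.
import glob  # used by get_quant_config

# Module path as in vLLM at the time; the pinned version may differ.
from vllm.model_executor.models import ModelRegistry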