Project: zhougaofeng / internlm2-math-7B

Commit 0efd8580: "Upload New File"
Authored Jun 11, 2024 by zhougaofeng
Parent: ca97c9b4
Pipeline #1176: canceled with stages
Showing 1 changed file with 88 additions and 0 deletions.

src/llmfactory/model/utils/unsloth.py (new file, mode 100644) +88 −0
from typing import TYPE_CHECKING, Any, Dict, Optional

from ...extras.logging import get_logger
from ...extras.misc import get_current_device


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def _get_unsloth_kwargs(
    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
    return {
        "model_name": model_name_or_path,
        "max_seq_length": model_args.model_max_length or 4096,
        "dtype": model_args.compute_dtype,
        "load_in_4bit": model_args.quantization_bit == 4,
        "token": model_args.hf_hub_token,
        "device_map": {"": get_current_device()},
        "rope_scaling": getattr(config, "rope_scaling", None),
        "fix_tokenizer": False,
        "trust_remote_code": True,
        "use_gradient_checkpointing": "unsloth",
    }


def load_unsloth_pretrained_model(
    config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
    r"""
    Optionally loads pretrained model with unsloth. Used in training.
    """
    from unsloth import FastLanguageModel

    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
    try:
        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
    except NotImplementedError:
        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
        model = None
        model_args.use_unsloth = False

    return model


def get_unsloth_peft_model(
    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
    r"""
    Gets the peft model for the pretrained model with unsloth. Used in training.
    """
    from unsloth import FastLanguageModel

    unsloth_peft_kwargs = {
        "model": model,
        "max_seq_length": model_args.model_max_length,
        "use_gradient_checkpointing": "unsloth",
    }
    return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)


def load_unsloth_peft_model(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
    r"""
    Loads peft model with unsloth. Used in both training and inference.
    """
    from unsloth import FastLanguageModel

    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
    try:
        if not is_trainable:
            unsloth_kwargs["use_gradient_checkpointing"] = False

        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
    except NotImplementedError:
        raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))

    if not is_trainable:
        FastLanguageModel.for_inference(model)

    return model
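
For context, below is a minimal usage sketch of how these helpers might be wired into a model loader. It is an illustration only, not part of this commit: the load_model_with_unsloth wrapper, the AutoConfig call site, and the LoRA settings in peft_kwargs are assumptions; only the three functions imported from src/llmfactory/model/utils/unsloth.py and the model_args fields they read (model_name_or_path, adapter_name_or_path, use_unsloth) come from the file above.

# Hypothetical usage sketch (not part of this commit). The wrapper function and
# the LoRA hyperparameters below are assumptions for illustration.
from transformers import AutoConfig

from llmfactory.model.utils.unsloth import (
    get_unsloth_peft_model,
    load_unsloth_peft_model,
    load_unsloth_pretrained_model,
)


def load_model_with_unsloth(model_args, is_trainable: bool):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)

    if is_trainable:
        # Training path: try the unsloth fast path first. It returns None and
        # flips model_args.use_unsloth off when the architecture is unsupported,
        # so the caller can fall back to a plain transformers load.
        model = load_unsloth_pretrained_model(config, model_args)
        if model is None:
            return None

        peft_kwargs = {  # assumed LoRA settings, normally built from fine-tuning arguments
            "r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.0,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
        }
        return get_unsloth_peft_model(model, model_args, peft_kwargs)

    # Inference path: load the base model plus the first adapter listed in
    # model_args.adapter_name_or_path, then switch unsloth into inference mode.
    return load_unsloth_peft_model(config, model_args, is_trainable=False)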