zhougaofeng / internlm2-math-7B · Commit e628f110

Upload New File

Authored Jun 11, 2024 by zhougaofeng
Parent: 0efd8580
Pipeline #1177: canceled

1 changed file with 58 additions and 0 deletions

src/llmfactory/model/utils/valuehead.py (new file, mode 100644)  +58  -0
from typing import TYPE_CHECKING, Dict

import torch
from transformers.utils import cached_file

from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.

    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
    """
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}

    try:
        from safetensors import safe_open

        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception as err:
        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))

    try:
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
        return torch.load(vhead_file, map_location="cpu")
    except Exception as err:
        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))

    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
    return None


def prepare_valuehead_model(model: "PreTrainedModel") -> None:
    if getattr(model.config, "model_type", None) == "llava":
        setattr(model, "lm_head", model.language_model.get_output_embeddings())
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "chatglm":
        setattr(model, "lm_head", model.transformer.output_layer)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "internlm2":
        setattr(model, "lm_head", model.output)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])