zhougaofeng / internlm2-math-7B / Commits / 93832447

Commit 93832447 ("Upload New File"), authored Jun 11, 2024 by zhougaofeng. Parent: 44e952ce. Pipeline #1149: canceled.
Showing 1 changed file with 228 additions and 0 deletions.

src/llmfactory/extras/misc.py (new file, mode 100644, +228 / -0):
import gc
import os
from typing import TYPE_CHECKING, Dict, Tuple

import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_mps_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)
from transformers.utils.versions import require_version

from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger


_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
    _is_bf16_available = is_torch_bf16_gpu_available()
except Exception:
    _is_bf16_available = False


if TYPE_CHECKING:
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import ModelArguments


logger = get_logger(__name__)

class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

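# Usage sketch (illustrative, not part of the original file): track a running
# training loss across batches with AverageMeter.
#
#   loss_meter = AverageMeter()
#   for loss in (0.9, 0.7, 0.5):
#       loss_meter.update(loss, n=1)
#   print(loss_meter.avg)  # arithmetic mean of all observed values
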
def check_dependencies() -> None:
    r"""
    Checks the version requirements of the core dependencies.
    """
    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
    else:
        require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
        require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
        require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")

def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                num_bytes = param.quant_storage.itemsize
            elif hasattr(param, "element_size"):  # for older pytorch version
                num_bytes = param.element_size()
            else:
                num_bytes = 1

            num_params = num_params * 2 * num_bytes

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param

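# Usage sketch (illustrative): report the trainable fraction of a model,
# e.g. after attaching LoRA adapters.
#
#   trainable, total = count_parameters(model)
#   print("trainable params: {} || all params: {} || trainable%: {:.4f}".format(
#       trainable, total, 100 * trainable / total))
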
def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        from safetensors import safe_open
        from safetensors.torch import save_file

        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

    decoder_state_dict = {}
    v_head_state_dict = {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "")] = param

    os.remove(path_to_checkpoint)
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info("Value head model saved at: {}".format(output_dir))

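# Illustrative sketch of the key split performed above (key names are examples):
#
#   "pretrained_model.model.layers.0.self_attn.q_proj.weight"
#       -> decoder_state_dict["model.layers.0.self_attn.q_proj.weight"]
#   "v_head.summary.weight"
#       -> v_head_state_dict["v_head.summary.weight"]
#
# The decoder weights are re-saved via save_pretrained, while the value head
# is written separately as V_HEAD_(SAFE_)WEIGHTS_NAME in the same output_dir.
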
def get_current_device() -> torch.device:
    r"""
    Gets the current available device.
    """
    if is_torch_xpu_available():
        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_mps_available():
        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_cuda_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
        device = "cpu"

    return torch.device(device)

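# Usage sketch (illustrative): under a distributed launcher that sets
# LOCAL_RANK (e.g. torchrun), each process resolves its own accelerator.
#
#   device = get_current_device()  # e.g. cuda:0 on rank 0, cuda:1 on rank 1
#   model.to(device)
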
def get_device_count() -> int:
    r"""
    Gets the number of available GPU devices.
    """
    if not torch.cuda.is_available():
        return 0

    return torch.cuda.device_count()


def get_logits_processor() -> "LogitsProcessorList":
    r"""
    Gets logits processor that removes NaN and Inf logits.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor

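# Usage sketch (illustrative): pass the processor to generate() so NaN/Inf
# logits are sanitized before sampling.
#
#   outputs = model.generate(**inputs, logits_processor=get_logits_processor())
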
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32

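# Usage sketch (illustrative): choose a compute dtype before loading weights.
# A bfloat16 checkpoint falls back to float16 on GPUs without bf16 support,
# and to float32 when no fp16-capable accelerator is present.
#
#   compute_dtype = infer_optim_dtype(model_dtype=torch.bfloat16)
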
def has_tokenized_data(path: os.PathLike) -> bool:
    r"""
    Checks if the path has a tokenized dataset.
    """
    return os.path.isdir(path) and len(os.listdir(path)) > 0


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

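# Usage sketch (illustrative): release cached CUDA memory after discarding
# a model, e.g. when swapping checkpoints in a long-running process.
#
#   del model
#   torch_gc()
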
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
    r"""
    Downloads the model from the ModelScope hub if it is enabled and the path does not exist locally.
    """
    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
        return model_args.model_name_or_path

    try:
        from modelscope import snapshot_download

        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
        return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")


def use_modelscope() -> bool:
    r"""
    Checks whether the ModelScope hub is enabled via the USE_MODELSCOPE_HUB environment variable.
    """
    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
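# Usage sketch (illustrative): with USE_MODELSCOPE_HUB=1 set in the
# environment, a hub model id is resolved through ModelScope and the local
# snapshot path is substituted before loading.
#
#   model_args.model_name_or_path = try_download_model_from_ms(model_args)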