zhougaofeng / internlm2-math-7B · Commits

Commit 974e4f24, authored Jun 11, 2024 by zhougaofeng
Upload New File
Parent: fe9a149a
Pipeline #1168: canceled with stages
1 changed file, 94 additions, 0 deletions

src/llmfactory/model/utils/checkpointing.py (new file, mode 100644, +94 -0)
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for the block-wise optimizer.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__

        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)


def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    return output.to(torch.float32)


def prepare_model_for_training(
    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
    r"""
    Includes:
        (1) cast the layernorm in fp32
        (2) make output embedding layer require grads
        (3) add the upcasting of the lm_head in fp32
    Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
    """
    if model_args.upcast_layernorm:
        logger.info("Upcasting layernorm weights in float32.")
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                param.data = param.data.to(torch.float32)

    if not model_args.disable_gradient_checkpointing:
        if not getattr(model, "supports_gradient_checkpointing", False):
            logger.warning("Current model does not support gradient checkpointing.")
        else:
            # use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
            # According to: https://github.com/huggingface/transformers/issues/28339
            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")

    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
        logger.info("Upcasting lm_head outputs in float32.")
        output_layer = getattr(model, output_layer_name)
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
            output_layer.register_forward_hook(_fp32_forward_post_hook)
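
A minimal usage sketch of prepare_model_for_training follows. It assumes the code in this repository is importable under the llmfactory package name (matching the committed file path), uses gpt2 only as an example checkpoint, and substitutes a SimpleNamespace for the real ModelArguments, since only the three flags read above are needed.

# Usage sketch (assumptions noted in the surrounding text; not part of the committed file).
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM

from llmfactory.model.utils.checkpointing import prepare_model_for_training  # assumed import path

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
model_args = SimpleNamespace(
    upcast_layernorm=True,                 # cast 1-D layernorm weights to float32
    disable_gradient_checkpointing=False,  # monkey-patch and enable gradient checkpointing
    upcast_lmhead_output=True,             # register the fp32 forward hook on lm_head
)

prepare_model_for_training(model, model_args)
assert model.config.use_cache is False  # caching is turned off once checkpointing is enabled

The call binds _gradient_checkpointing_enable onto the model instance via MethodType, so the patched behavior (forcing requires_grad on floating-point inputs when any module parameter is trainable) is used by the block-wise optimizer without modifying transformers itself.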
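
The lm_head upcasting relies only on a standard PyTorch forward hook, so it can be illustrated in isolation; the snippet below is a self-contained sketch with no llmfactory dependencies.

# Standalone illustration of the fp32 forward-hook upcasting applied to lm_head above.
import torch

head = torch.nn.Linear(16, 32, dtype=torch.bfloat16)

# Returning a value from a forward hook replaces the module's output,
# which is what _fp32_forward_post_hook does for the output layer.
head.register_forward_hook(lambda module, args, output: output.to(torch.float32))

x = torch.randn(4, 16, dtype=torch.bfloat16)
print(head(x).dtype)  # torch.float32, while head.weight stays in bfloat16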