zhougaofeng / internlm2-math-7B

Commit 90635e34, authored Jun 11, 2024 by zhougaofeng
Upload New File
Parent: e628f110 · Pipeline #1178: canceled with stages

Showing 1 changed file with 84 additions and 0 deletions.

src/llmfactory/model/utils/visual.py (new file, 0 → 100644, +84 −0)
from typing import TYPE_CHECKING, Tuple

import torch
import transformers.models
from transformers.activations import ACT2FN

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
    # Yi-VL projector: Linear -> LayerNorm -> activation -> Linear -> LayerNorm,
    # unlike the stock Llava projector (Linear -> activation -> Linear).
    def __init__(self, config: "LlavaConfig") -> None:
        super().__init__()

        self.config = config
        if config is None:  # allow subclasses to build the layers themselves
            return

        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]

    def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
        hidden_states = self.linear_1(image_features)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_3(hidden_states)
        hidden_states = self.linear_4(hidden_states)
        if hidden_states.dtype == torch.float32:  # cast back if the activations were silently upcast
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.linear_1.weight.dtype

            logger.warning_once("The hidden states seem to be silently cast in float32.")
            hidden_states = hidden_states.to(target_dtype)

        return hidden_states


class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
    # vLLM builds the projector from plain sizes rather than a LlavaConfig.
    def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
        super().__init__(config=None)

        self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.act = ACT2FN[projector_hidden_act]


def autocast_projector_dtype(
    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
    def _mm_projector_forward_post_hook(
        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
    ) -> "torch.Tensor":
        return output.to(model_args.compute_dtype)

    # Quantized backbones may emit projector outputs in a different dtype,
    # so cast them back to the configured compute dtype via a forward hook.
    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
        logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


def configure_visual_model(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

    if getattr(config, "is_yi_vl_derived_model", None):
        logger.info("Detected Yi-VL model, applying projector patch.")
        # Monkey-patch transformers so Llava models are built with the Yi-VL projector.
        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
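
For context, and not part of the committed file: a minimal usage sketch of the projector and the patch hook above. It assumes the package is importable as llmfactory.model.utils.visual (inferred from the src/ layout) and uses the default transformers LlavaConfig; the 576 image tokens and the manually set is_yi_vl_derived_model flag are illustrative assumptions (the flag would normally come from the model's config.json).

import torch
from transformers import LlavaConfig

from llmfactory.model.utils.visual import (  # hypothetical import path
    LlavaMultiModalProjectorForYiVL,
    configure_visual_model,
)

config = LlavaConfig()  # defaults: CLIP vision tower + Llama text backbone
config.is_yi_vl_derived_model = True  # illustrative; normally set in config.json
configure_visual_model(config)  # patches transformers to use the Yi-VL projector

# Project dummy image features into the text embedding space:
# (batch, num_image_tokens, vision_hidden) -> (batch, num_image_tokens, text_hidden)
projector = LlavaMultiModalProjectorForYiVL(config)
image_features = torch.randn(1, 576, config.vision_config.hidden_size)
assert projector(image_features).shape[-1] == config.text_config.hidden_size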