Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6206dcb2
Unverified
Commit
6206dcb2
authored
Jul 06, 2024
by
Roger Wang
Committed by
GitHub
Jul 07, 2024
Browse files
[Model] Add PaliGemma (#5189)
Co-authored-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
93893800
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
557 additions
and
2 deletions
+557
-2
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+4
-0
examples/paligemma_example.py
examples/paligemma_example.py
+52
-0
tests/models/test_paligemma.py
tests/models/test_paligemma.py
+147
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+2
-0
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+8
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+344
-0
No files found.
docs/source/models/supported_models.rst
View file @
6206dcb2
...
...
@@ -186,6 +186,10 @@ Vision Language Models
- LLaVA-NeXT
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
...
...
examples/paligemma_example.py
0 → 100644
View file @
6206dcb2
import
os
import
subprocess
from
PIL
import
Image
from
vllm
import
LLM
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them
def
run_paligemma
():
llm
=
LLM
(
model
=
"google/paligemma-3b-mix-224"
)
prompt
=
"caption es"
image
=
Image
.
open
(
"images/stop_sign.jpg"
)
outputs
=
llm
.
generate
({
"prompt"
:
prompt
,
"multi_modal_data"
:
{
"image"
:
image
},
})
for
o
in
outputs
:
generated_text
=
o
.
outputs
[
0
].
text
print
(
generated_text
)
def
main
():
run_paligemma
()
if
__name__
==
"__main__"
:
# Download from s3
s3_bucket_path
=
"s3://air-example-data-2/vllm_opensource_llava/"
local_directory
=
"images"
# Make sure the local directory exists or create it
os
.
makedirs
(
local_directory
,
exist_ok
=
True
)
# Use AWS CLI to sync the directory, assume anonymous access
subprocess
.
check_call
([
"aws"
,
"s3"
,
"sync"
,
s3_bucket_path
,
local_directory
,
"--no-sign-request"
,
])
main
()
tests/models/test_paligemma.py
0 → 100644
View file @
6206dcb2
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
pytest
from
transformers
import
AutoTokenizer
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
.utils
import
check_logprobs_close
pytestmark
=
pytest
.
mark
.
vlm
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"caption es"
,
"cherry_blossom"
:
"What is in the picture?"
,
"boardwalk"
:
"What is in the picture?"
,
})
IMAGE_TOKEN_ID
=
257152
models
=
[
"google/paligemma-3b-mix-224"
]
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
model
:
str
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
eos_token_id
=
tokenizer
.
eos_token_id
hf_output_ids
=
[
token_id
for
idx
,
token_id
in
enumerate
(
output_ids
)
if
token_id
!=
IMAGE_TOKEN_ID
or
output_ids
[
idx
-
1
]
!=
IMAGE_TOKEN_ID
]
hf_output_str
=
output_str
if
hf_output_ids
[
-
1
]
==
eos_token_id
:
hf_output_str
=
hf_output_str
+
tokenizer
.
decode
(
eos_token_id
)
return
hf_output_ids
,
hf_output_str
,
out_logprobs
def
run_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
image_assets
:
_ImageAssets
,
model
:
str
,
*
,
size_factors
:
List
[
float
],
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs_per_image
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
)
for
prompts
,
images
in
inputs_per_image
]
with
hf_runner
(
model
,
dtype
=
dtype
,
is_vision_model
=
True
)
as
hf_model
:
hf_outputs_per_image
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
)
for
prompts
,
images
in
inputs_per_image
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_image
,
vllm_outputs_per_image
):
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
[
vllm_to_hf_output
(
vllm_output
,
model
)
for
vllm_output
in
vllm_outputs
],
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# No image
[],
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
run_test
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
=
size_factors
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
vllm/model_executor/models/__init__.py
View file @
6206dcb2
...
...
@@ -49,6 +49,8 @@ _GENERATION_MODELS = {
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
...
...
vllm/model_executor/models/gemma.py
View file @
6206dcb2
...
...
@@ -268,16 +268,22 @@ class GemmaModel(nn.Module):
normalizer
=
self
.
config
.
hidden_size
**
0.5
self
.
register_buffer
(
"normalizer"
,
torch
.
tensor
(
normalizer
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
hidden_states
*=
self
.
normalizer
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
...
...
vllm/model_executor/models/paligemma.py
0 → 100644
View file @
6206dcb2
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
import
torch
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
PaliGemmaConfig
,
SiglipVisionConfig
,
SiglipVisionModel
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.gemma
import
GemmaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
SamplerOutput
,
SequenceData
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.model"
:
"language_model"
,
}
def
get_max_paligemma_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
text_config
=
hf_config
.
text_config
return
text_config
.
num_image_tokens
def
dummy_seq_data_for_paligemma
(
hf_config
:
PaliGemmaConfig
,
seq_len
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
):
if
image_feature_size_override
is
None
:
image_feature_size
=
hf_config
.
text_config
.
num_image_tokens
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_paligemma
(
hf_config
:
SiglipVisionConfig
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
width
=
height
=
hf_config
.
image_size
if
image_width_override
is
not
None
:
width
=
image_width_override
if
image_height_override
is
not
None
:
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
vision_config
=
hf_config
.
vision_config
seq_data
=
dummy_seq_data_for_paligemma
(
hf_config
,
seq_len
,
image_token_id
=
hf_config
.
image_token_index
,
)
mm_data
=
dummy_image_for_paligemma
(
vision_config
)
return
seq_data
,
mm_data
def
input_processor_for_paligemma
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
"""
The correct prompt format needs to be:
'<image>' * image_feature_size + '<bos>' + prompt + '
\n
'
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
"""
# noqa
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
image_feature_size
=
hf_config
.
text_config
.
num_image_tokens
image_token_str
=
tokenizer
.
decode
(
hf_config
.
image_token_index
)
bos_token
=
tokenizer
.
decode
(
hf_config
.
bos_token_id
)
image_token_str_pad
=
image_token_str
*
image_feature_size
image_token_ids_pad
=
[
hf_config
.
image_token_index
]
*
image_feature_size
orig_prompt
=
llm_inputs
.
get
(
"prompt"
)
orig_prompt_ids
=
llm_inputs
.
get
(
"prompt_token_ids"
)
if
image_token_str
in
orig_prompt
:
logger
.
warning
(
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
" documented on HuggingFace."
,
image_token_str
)
orig_prompt
=
orig_prompt
.
replace
(
image_token_str
,
""
)
orig_prompt_ids
.
remove
(
hf_config
.
image_token_index
)
new_prompt
=
f
"
{
image_token_str_pad
}{
bos_token
}{
orig_prompt
}
\n
"
new_token_ids
=
image_token_ids_pad
+
orig_prompt_ids
+
[
108
]
#newline
# NOTE: Create a defensive copy of the original inputs
return
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
class
PaliGemmaMultiModalProjector
(
nn
.
Module
):
def
__init__
(
self
,
vision_hidden_size
:
int
,
projection_dim
:
int
):
super
().
__init__
()
self
.
linear
=
ColumnParallelLinear
(
vision_hidden_size
,
projection_dim
,
bias
=
True
)
def
forward
(
self
,
image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
linear
(
image_features
)
return
hidden_states
class
PaliGemmaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
PaliGemmaImageInputs
=
PaliGemmaImagePixelInputs
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_paligemma_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_paligemma
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_paligemma
)
class
PaliGemmaForConditionalGeneration
(
nn
.
Module
,
SupportsVision
):
def
__init__
(
self
,
config
:
PaliGemmaConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# TODO(ywang96): Port over SiglipVisionModel & TP
self
.
vision_tower
=
SiglipVisionModel
(
config
.
vision_config
)
self
.
multi_modal_projector
=
PaliGemmaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
projection_dim
=
config
.
vision_config
.
projection_dim
)
self
.
quant_config
=
quant_config
self
.
language_model
=
GemmaModel
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
actual_dims
=
tuple
(
data
.
shape
[
1
:])
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"batch_size"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
f
"The expected shape of pixel values is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
data
.
shape
)
}
."
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
PaliGemmaImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
pixel_values
is
None
:
return
None
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
PaliGemmaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
image_outputs
=
vision_tower
(
pixel_values
,
output_hidden_states
=
True
)
selected_image_features
=
image_outputs
.
last_hidden_state
return
selected_image_features
def
_process_image_pixels
(
self
,
inputs
:
PaliGemmaImagePixelInputs
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"data"
]
return
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
)
def
_process_image_input
(
self
,
image_input
:
PaliGemmaImageInputs
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
return
self
.
multi_modal_projector
(
image_features
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
object
)
->
SamplerOutput
:
parsed_image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
parsed_image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
parsed_image_input
)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings
=
vision_embeddings
*
(
self
.
config
.
hidden_size
**
-
0.5
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
# Copied from vllm/model_executor/models/gemma.py
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
language_model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
return
logits
# Copied from vllm/model_executor/models/gemma.py
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
# Adapted from vllm/model_executor/models/gemma.py
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
=
set
()
for
name
,
loaded_weight
in
weights
:
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_tower
is
not
None
:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
else
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
use_default_weight_loading
=
True
if
use_default_weight_loading
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
if
unloaded_params
:
raise
RuntimeError
(
"Some weights are not initialized from checkpoints: "
f
"
{
unloaded_params
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment