Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c4e46433
Unverified
Commit
c4e46433
authored
Nov 18, 2024
by
Isotr0py
Committed by
GitHub
Nov 18, 2024
Browse files
[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
d1557e66
Changes
74
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
114 additions
and
52 deletions
+114
-52
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+6
-3
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+6
-2
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+6
-2
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+6
-2
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+6
-2
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+7
-2
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+5
-3
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+8
-3
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+4
-3
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+6
-2
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+6
-2
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+4
-3
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+6
-2
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+6
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+10
-5
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+4
-3
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+4
-3
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+4
-3
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+4
-3
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+6
-2
No files found.
vllm/model_executor/models/gemma2.py
View file @
c4e46433
...
...
@@ -312,7 +312,8 @@ class Gemma2Model(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
...
@@ -354,6 +355,7 @@ class Gemma2Model(nn.Module):
logger
.
warning
(
"Some weights are not initialized from checkpoints: %s"
,
unloaded_params
)
return
loaded_params
class
Gemma2ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
...
...
@@ -451,13 +453,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
class
Gemma2EmbeddingModel
(
nn
.
Module
,
SupportsPP
):
...
...
vllm/model_executor/models/gpt2.py
View file @
c4e46433
...
...
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -298,8 +298,10 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"lm_head.weight"
in
name
:
# GPT-2 ties the weights of the embedding layer and the final
...
...
@@ -328,3 +330,5 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/gpt_bigcode.py
View file @
c4e46433
...
...
@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -323,8 +323,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"lm_head.weight"
in
name
:
continue
...
...
@@ -344,3 +346,5 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
(
param
,
loaded_weight
,
'v'
)
else
:
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/gpt_j.py
View file @
c4e46433
...
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -291,7 +291,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
...
@@ -301,6 +302,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"attn.bias"
in
name
or
"attn.masked_bias"
in
name
:
continue
...
...
@@ -330,3 +332,5 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/gpt_neox.py
View file @
c4e46433
...
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -303,8 +303,10 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
(
"attention.bias"
in
name
or
"attention.masked_bias"
in
name
or
"rotary_emb.inv_freq"
in
name
):
...
...
@@ -337,3 +339,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/granite.py
View file @
c4e46433
...
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -455,7 +455,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device
=
device
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
...
...
@@ -465,6 +466,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
...
...
@@ -485,6 +487,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader
)
loaded_weight
=
loaded_weight
[
0
]
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
...
...
@@ -518,6 +521,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
...
...
vllm/model_executor/models/granitemoe.py
View file @
c4e46433
...
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -419,7 +419,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
new_weights
=
{}
for
n
,
p
in
weights
:
if
n
.
endswith
(
'.block_sparse_moe.input_linear.weight'
):
...
...
@@ -452,4 +453,5 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
pass
else
:
new_weights
[
n
]
=
p
mixtral
.
MixtralForCausalLM
.
load_weights
(
self
,
new_weights
.
items
())
return
mixtral
.
MixtralForCausalLM
.
load_weights
(
self
,
new_weights
.
items
())
vllm/model_executor/models/idefics2_vision_model.py
View file @
c4e46433
...
...
@@ -15,7 +15,7 @@
# limitations under the License.
"""PyTorch Idefics2 model."""
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -331,7 +331,8 @@ class Idefics2VisionTransformer(nn.Module):
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
return
last_hidden_state
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
...
@@ -339,11 +340,13 @@ class Idefics2VisionTransformer(nn.Module):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
...
...
@@ -352,3 +355,5 @@ class Idefics2VisionTransformer(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/idefics3.py
View file @
c4e46433
...
...
@@ -15,7 +15,7 @@
import
math
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
NamedTuple
,
Optional
,
Tuple
,
TypedDict
,
Union
)
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.utils.checkpoint
...
...
@@ -751,9 +751,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
...
...
vllm/model_executor/models/intern_vit.py
View file @
c4e46433
...
...
@@ -5,7 +5,7 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
functools
import
partial
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
import
torch.nn
as
nn
...
...
@@ -469,10 +469,14 @@ class InternVisionModel(nn.Module):
return
encoder_outputs
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/internlm2.py
View file @
c4e46433
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -369,13 +369,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"gate_up_proj"
,
"w1"
,
0
),
(
"gate_up_proj"
,
"w3"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
...
...
@@ -402,3 +404,5 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/internvl.py
View file @
c4e46433
...
...
@@ -6,7 +6,7 @@
# --------------------------------------------------------
import
re
from
functools
import
cached_property
,
partial
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
...
...
@@ -663,6 +663,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/jais.py
View file @
c4e46433
...
...
@@ -19,7 +19,7 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -350,8 +350,10 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"lm_head.weight"
in
name
:
# GPT-2 ties the weights of the embedding layer and the final
...
...
@@ -382,3 +384,5 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/jamba.py
View file @
c4e46433
"""Inference-only Jamba model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -462,7 +462,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
...
@@ -479,6 +480,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
...
...
@@ -534,6 +536,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
def
_is_moe_layer
(
name
:
str
):
...
...
vllm/model_executor/models/llama.py
View file @
c4e46433
...
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -350,7 +350,8 @@ class LlamaModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
...
...
@@ -360,6 +361,7 @@ class LlamaModel(nn.Module):
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
...
...
@@ -375,6 +377,7 @@ class LlamaModel(nn.Module):
default_weight_loader
)
loaded_weight
=
loaded_weight
[
0
]
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
...
...
@@ -390,7 +393,6 @@ class LlamaModel(nn.Module):
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
...
...
@@ -408,6 +410,8 @@ class LlamaModel(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
...
...
@@ -577,13 +581,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
loader
.
load_weights
(
return
loader
.
load_weights
(
self
.
maybe_remap_mistral
(
name
,
loaded_weight
)
for
name
,
loaded_weight
in
weights
)
...
...
vllm/model_executor/models/llava.py
View file @
c4e46433
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
...
...
@@ -547,6 +547,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/llava_next.py
View file @
c4e46433
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
...
...
@@ -654,6 +654,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/llava_next_video.py
View file @
c4e46433
import
math
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
...
...
@@ -445,10 +445,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
# This model doesn't support images for now
ignore_unexpected_prefixes
=
[
"image_newline"
],
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/llava_onevision.py
View file @
c4e46433
import
math
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
...
...
@@ -887,6 +887,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/mamba.py
View file @
c4e46433
"""PyTorch MAMBA model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -243,8 +243,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
...
...
@@ -256,3 +258,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment