Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7025b11d
Unverified
Commit
7025b11d
authored
Aug 13, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 13, 2024
Browse files
[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
parent
5469146b
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
106 additions
and
43 deletions
+106
-43
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+5
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+5
-2
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+5
-2
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+5
-2
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+5
-2
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+5
-2
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+5
-2
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+5
-2
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+5
-2
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+5
-2
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+5
-2
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+5
-2
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+5
-2
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+5
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+5
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+5
-2
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+5
-2
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+11
-5
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+5
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+5
-2
No files found.
vllm/model_executor/models/deepseek.py
View file @
7025b11d
...
@@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
7025b11d
...
@@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module):
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/falcon.py
View file @
7025b11d
...
@@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module):
...
@@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module):
)
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/fuyu.py
View file @
7025b11d
...
@@ -287,8 +287,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
...
@@ -287,8 +287,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
)
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
language_model
.
logits_processor
(
logits
=
self
.
language_model
.
logits_processor
(
self
.
language_model
.
lm_head
,
hidden_states
,
sampling_metadata
)
self
.
language_model
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/gemma.py
View file @
7025b11d
...
@@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/gemma2.py
View file @
7025b11d
...
@@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/gpt2.py
View file @
7025b11d
...
@@ -265,8 +265,11 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -265,8 +265,11 @@ class GPT2LMHeadModel(nn.Module):
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
7025b11d
...
@@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
...
@@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/gpt_j.py
View file @
7025b11d
...
@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
...
@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
return
logits
...
...
vllm/model_executor/models/gpt_neox.py
View file @
7025b11d
...
@@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module):
...
@@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
embed_out
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
embed_out
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/internlm2.py
View file @
7025b11d
...
@@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
output
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
output
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/internvl.py
View file @
7025b11d
...
@@ -466,8 +466,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
...
@@ -466,8 +466,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
sampling_metadata
)
...
...
vllm/model_executor/models/jais.py
View file @
7025b11d
...
@@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module):
...
@@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/jamba.py
View file @
7025b11d
...
@@ -861,8 +861,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -861,8 +861,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
dtype
=
dtype
,
dtype
=
dtype
,
device
=
"cuda"
))
device
=
"cuda"
))
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/llama.py
View file @
7025b11d
...
@@ -430,8 +430,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -430,8 +430,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
)
return
model_output
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/llava.py
View file @
7025b11d
...
@@ -355,8 +355,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -355,8 +355,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
sampling_metadata
)
...
...
vllm/model_executor/models/llava_next.py
View file @
7025b11d
...
@@ -588,8 +588,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -588,8 +588,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
sampling_metadata
)
...
...
vllm/model_executor/models/medusa.py
View file @
7025b11d
...
@@ -65,22 +65,28 @@ class Medusa(nn.Module):
...
@@ -65,22 +65,28 @@ class Medusa(nn.Module):
def
compute_logits
(
def
compute_logits
(
self
,
hidden_states
:
List
[
torch
.
Tensor
],
self
,
hidden_states
:
List
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
)
->
List
[
torch
.
Tensor
]:
sampling_metadata
:
SamplingMetadata
)
->
List
[
torch
.
Tensor
]:
logits
=
[]
logits
_lst
:
List
[
torch
.
Tensor
]
=
[]
for
hs
,
lm_head
in
zip
(
hidden_states
,
self
.
lm_heads
):
for
hs
,
lm_head
in
zip
(
hidden_states
,
self
.
lm_heads
):
_logits
=
self
.
logits_processor
(
lm_head
,
hs
,
sampling_metadata
)
_logits
=
self
.
logits_processor
(
lm_head
,
hs
,
sampling_metadata
)
if
_logits
is
None
:
# _logits should only be None on rank > 0, in which case
# it should remain true for every lm_head
assert
len
(
logits_lst
)
==
0
continue
if
self
.
token_map
is
None
:
if
self
.
token_map
is
None
:
logits
.
append
(
_logits
)
logits
_lst
.
append
(
_logits
)
else
:
else
:
logits
.
append
(
-
torch
.
inf
*
torch
.
ones
(
logits
_lst
.
append
(
-
torch
.
inf
*
torch
.
ones
(
size
=
(
*
_logits
.
shape
[:
-
1
],
self
.
orig_vocab_size
),
size
=
(
*
_logits
.
shape
[:
-
1
],
self
.
orig_vocab_size
),
device
=
_logits
.
device
,
device
=
_logits
.
device
,
dtype
=
_logits
.
dtype
))
dtype
=
_logits
.
dtype
))
logits
[
-
1
][...,
self
.
token_map
]
=
_logits
logits
_lst
[
-
1
][...,
self
.
token_map
]
=
_logits
return
logits
return
logits
_lst
def
sample
(
def
sample
(
self
,
self
,
...
...
vllm/model_executor/models/minicpm.py
View file @
7025b11d
...
@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
...
@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
hidden_states
=
hidden_states
/
self
.
scale_width
hidden_states
=
hidden_states
/
self
.
scale_width
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
lm_head
=
self
.
model
.
embed_tokens
lm_head
=
self
.
model
.
embed_tokens
...
...
vllm/model_executor/models/minicpmv.py
View file @
7025b11d
...
@@ -630,8 +630,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
...
@@ -630,8 +630,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
)
)
return
output
return
output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment