Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7025b11d
Unverified
Commit
7025b11d
authored
Aug 13, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 13, 2024
Browse files
[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
parent
5469146b
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
95 additions
and
39 deletions
+95
-39
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+5
-2
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+5
-2
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+5
-2
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+5
-2
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+5
-2
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+5
-2
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+5
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+5
-2
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+5
-2
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+5
-2
vllm/model_executor/models/phi3_small.py
vllm/model_executor/models/phi3_small.py
+5
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+5
-2
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+5
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+5
-2
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+5
-2
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+5
-2
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+5
-2
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+5
-2
vllm/outputs.py
vllm/outputs.py
+5
-3
No files found.
vllm/model_executor/models/mixtral.py
View file @
7025b11d
...
...
@@ -375,8 +375,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
7025b11d
...
...
@@ -362,8 +362,11 @@ class MixtralForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/mpt.py
View file @
7025b11d
...
...
@@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/nemotron.py
View file @
7025b11d
...
...
@@ -453,8 +453,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/olmo.py
View file @
7025b11d
...
...
@@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module):
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/opt.py
View file @
7025b11d
...
...
@@ -323,8 +323,11 @@ class OPTForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/orion.py
View file @
7025b11d
...
...
@@ -277,8 +277,11 @@ class OrionForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/paligemma.py
View file @
7025b11d
...
...
@@ -262,8 +262,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
# Copied from vllm/model_executor/models/gemma.py
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
language_model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/persimmon.py
View file @
7025b11d
...
...
@@ -285,8 +285,11 @@ class PersimmonForCausalLM(nn.Module):
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/phi.py
View file @
7025b11d
...
...
@@ -286,8 +286,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
...
...
vllm/model_executor/models/phi3_small.py
View file @
7025b11d
...
...
@@ -399,8 +399,11 @@ class Phi3SmallForCausalLM(nn.Module):
def
get_decoder
(
self
):
return
self
.
model
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
if
self
.
dummy_token_indices
is
not
None
and
logits
is
not
None
:
...
...
vllm/model_executor/models/phi3v.py
View file @
7025b11d
...
...
@@ -584,8 +584,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/qwen.py
View file @
7025b11d
...
...
@@ -281,8 +281,11 @@ class QWenLMHeadModel(nn.Module):
device
=
device
),
})
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/qwen2.py
View file @
7025b11d
...
...
@@ -362,8 +362,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
7025b11d
...
...
@@ -400,8 +400,11 @@ class Qwen2MoeForCausalLM(nn.Module):
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/stablelm.py
View file @
7025b11d
...
...
@@ -258,8 +258,11 @@ class StablelmForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/starcoder2.py
View file @
7025b11d
...
...
@@ -268,8 +268,11 @@ class Starcoder2ForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/xverse.py
View file @
7025b11d
...
...
@@ -328,8 +328,11 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/outputs.py
View file @
7025b11d
import
time
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
...
...
@@ -28,7 +30,7 @@ class CompletionOutput:
index
:
int
text
:
str
token_ids
:
Tuple
[
int
,
...
]
token_ids
:
GenericSequence
[
int
]
cumulative_logprob
:
Optional
[
float
]
logprobs
:
Optional
[
SampleLogprobs
]
finish_reason
:
Optional
[
str
]
=
None
...
...
@@ -139,7 +141,7 @@ class RequestOutput:
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
data
.
_output_token_ids
,
# type: ignore
seq
.
data
.
_output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
seq
.
output_logprobs
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment