Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2789112
Unverified
Commit
b2789112
authored
Apr 25, 2025
by
Woosuk Kwon
Committed by
GitHub
Apr 25, 2025
Browse files
[Minor][Models] Fix Return Types of Llama & Eagle (#17220)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
7bd0c774
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
5 deletions
+6
-5
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-1
vllm/model_executor/models/llama_eagle.py
vllm/model_executor/models/llama_eagle.py
+2
-2
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+2
-2
No files found.
vllm/model_executor/models/llama.py
View file @
b2789112
...
@@ -345,7 +345,8 @@ class LlamaModel(nn.Module):
...
@@ -345,7 +345,8 @@ class LlamaModel(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
,
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
...
...
vllm/model_executor/models/llama_eagle.py
View file @
b2789112
...
@@ -70,7 +70,7 @@ class LlamaModel(nn.Module):
...
@@ -70,7 +70,7 @@ class LlamaModel(nn.Module):
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
input_embeds
=
self
.
embed_tokens
(
input_ids
)
input_embeds
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
fc
(
hidden_states
=
self
.
fc
(
torch
.
cat
((
input_embeds
,
hidden_states
),
dim
=-
1
))
torch
.
cat
((
input_embeds
,
hidden_states
),
dim
=-
1
))
...
@@ -133,7 +133,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
...
@@ -133,7 +133,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
return
self
.
model
(
input_ids
,
positions
,
hidden_states
)
return
self
.
model
(
input_ids
,
positions
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
b2789112
...
@@ -117,7 +117,7 @@ class LlamaModel(nn.Module):
...
@@ -117,7 +117,7 @@ class LlamaModel(nn.Module):
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
input_embeds
=
self
.
embed_tokens
(
input_ids
)
input_embeds
=
self
.
embed_tokens
(
input_ids
)
if
(
hidden_states
.
shape
[
-
1
]
!=
input_embeds
.
shape
[
-
1
]):
if
(
hidden_states
.
shape
[
-
1
]
!=
input_embeds
.
shape
[
-
1
]):
hidden_states
=
self
.
fc
(
hidden_states
)
hidden_states
=
self
.
fc
(
hidden_states
)
...
@@ -194,7 +194,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
...
@@ -194,7 +194,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
return
self
.
model
(
input_ids
,
positions
,
hidden_states
)
return
self
.
model
(
input_ids
,
positions
,
hidden_states
)
def
compute_logits
(
def
compute_logits
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment