change/sglang · Commits · 64fe3115

Commit 64fe3115 (unverified), authored Mar 11, 2024 by Geary.Z, committed via GitHub on Mar 10, 2024
Parent: a7ace9c8

    replace skip_embed with input_embeds (#222)

Showing 4 changed files with 17 additions and 17 deletions (+17, -17)
python/sglang/srt/models/llama2.py    +5  -5
python/sglang/srt/models/llava.py     +2  -2
python/sglang/srt/models/mixtral.py   +5  -5
python/sglang/srt/models/qwen2.py     +5  -5
python/sglang/srt/models/llama2.py

@@ -227,12 +227,12 @@ class LlamaModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]

@@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
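The backbone-side change is the same in llama2, mixtral, and qwen2: the boolean skip_embed flag becomes an optional input_embeds tensor, and the token-embedding lookup runs only when no precomputed embeddings are supplied. A minimal, self-contained sketch of that pattern follows; the TinyModel class, vocabulary size, and hidden size are illustrative stand-ins, not sglang code.

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        # Illustrative stand-in for the changed forward() signature in the diff above.
        def __init__(self, vocab_size: int = 100, hidden_size: int = 16):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        def forward(
            self,
            input_ids: torch.Tensor,
            input_embeds: torch.Tensor = None,
        ) -> torch.Tensor:
            # New convention: look up embeddings only when none were provided.
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            return hidden_states

    model = TinyModel()
    ids = torch.tensor([1, 2, 3])
    print(model(ids).shape)                                    # lookup from token ids
    print(model(ids, input_embeds=torch.randn(3, 16)).shape)   # lookup skipped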
python/sglang/srt/models/llava.py

@@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module):
                     pt += 1

             return self.language_model(
-                input_embeds, positions, input_metadata, skip_embed=True
+                input_ids, positions, input_metadata, input_embeds=input_embeds
             )
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(
-                input_ids, positions, input_metadata, skip_embed=False
+                input_ids, positions, input_metadata
             )

     def load_weights(
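On the caller side, llava.py previously passed the multimodal embeddings through the input_ids positional slot and flagged skip_embed=True; after the change, input_ids stays in place, the embeddings travel as an explicit input_embeds keyword, and the decode path simply omits the argument. A small stand-in sketch of the two call sites (the language_model function here is a hypothetical mock, not the sglang module):

    # Hypothetical mock of the callee, used only to show the calling convention.
    def language_model(input_ids, positions, input_metadata, input_embeds=None):
        return "embeds path" if input_embeds is not None else "token-id path"

    # Prefill: precomputed embeddings are passed explicitly (was skip_embed=True).
    print(language_model([1, 2, 3], [0, 1, 2], None, input_embeds=[[0.0] * 16] * 3))
    # Decode: no precomputed embeddings, so the argument is omitted (was skip_embed=False).
    print(language_model([1, 2, 3], [0, 1, 2], None))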
python/sglang/srt/models/mixtral.py

@@ -296,12 +296,12 @@ class MixtralModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]

@@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
python/sglang/srt/models/qwen2.py

@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]

@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )