36d5acfc
Unverified
Commit
36d5acfc
authored
Sep 30, 2024
by
Lianmin Zheng
Committed by
GitHub
Sep 30, 2024
Browse files
Rename InputMetadata -> ForwardBatch (#1543)
parent
3f0fe08d
Changes: 44 files in this commit. This page shows 4 changed files with 40 additions and 40 deletions.
python/sglang/srt/models/qwen2_moe.py   +10 -10
python/sglang/srt/models/stablelm.py    +10 -10
python/sglang/srt/models/xverse.py      +10 -10
python/sglang/srt/models/xverse_moe.py  +10 -10
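All four files apply the same mechanical change: the forward-pass parameter `input_metadata: InputMetadata` becomes `forward_batch: ForwardBatch`, and every call site is updated to match. As a rough orientation before the diffs, the following minimal sketch shows the shape of the renamed signature; `ToyAttention` and the stub `ForwardBatch` are hypothetical stand-ins, not code from this commit.

# Minimal sketch of the rename pattern in this commit. `ToyAttention` and the
# stub `ForwardBatch` below are hypothetical; the real ForwardBatch lives in
# sglang.srt.model_executor.forward_batch_info.
import torch
from torch import nn


class ForwardBatch:
    """Stand-in for sglang's ForwardBatch (formerly InputMetadata): the
    per-step batch metadata (sequence lengths, KV-cache indices, ...)."""


class ToyAttention(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    # Before this commit the last parameter was `input_metadata: InputMetadata`.
    # Only the parameter name and its annotation change; positional call sites
    # such as `self.attn(q, k, v, forward_batch)` stay structurally identical.
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        # A real sglang model would hand forward_batch to RadixAttention here.
        return self.proj(hidden_states)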
python/sglang/srt/models/qwen2_moe.py (view file @ 36d5acfc)

@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class Qwen2MoeMLP(nn.Module):

@@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention

@@ -293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected

@@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:

@@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

@@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/stablelm.py (view file @ 36d5acfc)

@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class StablelmMLP(nn.Module):

@@ -145,12 +145,12 @@ class StablelmAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states

@@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states

@@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:

@@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
             )
         hidden_states = self.norm(hidden_states)
         return hidden_states

@@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
python/sglang/srt/models/xverse.py (view file @ 36d5acfc)

@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch


 class XverseMLP(nn.Module):

@@ -160,12 +160,12 @@ class XverseAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention

@@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected

@@ -271,7 +271,7 @@ class XverseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:

@@ -284,7 +284,7 @@ class XverseModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
             # print(f"layer[{i}].hidden_states: {hidden_states}")

@@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(
python/sglang/srt/models/xverse_moe.py (view file @ 36d5acfc)

@@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class XverseMLP(nn.Module):

@@ -244,12 +244,12 @@ class XverseAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention

@@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected

@@ -353,14 +353,14 @@ class XverseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

@@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
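This page shows only 4 of the commit's 44 changed files. Downstream code that still imports `InputMetadata` has to switch to the new name; none of the four files shown here adds a compatibility shim. If a temporary bridge were wanted during such a migration, it could be as small as the following sketch, which is hypothetical and not part of this commit:

# Hypothetical migration shim -- NOT part of this commit. One way to keep old
# call sites and type hints working while the InputMetadata -> ForwardBatch
# rename propagates through a downstream codebase.
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

# Old spelling kept as an alias; isinstance(x, InputMetadata) checks and old
# annotations keep working until every caller switches to ForwardBatch.
InputMetadata = ForwardBatch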