Commit 36d5acfc (unverified)
Authored Sep 30, 2024 by Lianmin Zheng; committed by GitHub on Sep 30, 2024
Parent: 3f0fe08d

Rename InputMetadata -> ForwardBatch (#1543)
Changes: 44 files changed in this commit. Showing 20 changed files on this page, with 194 additions and 194 deletions (+194 / -194).
python/sglang/srt/models/deepseek_v2.py              +12 -12
python/sglang/srt/models/exaone.py                   +10 -10
python/sglang/srt/models/gemma.py                    +10 -10
python/sglang/srt/models/gemma2.py                   +10 -10
python/sglang/srt/models/gpt_bigcode.py              +10 -10
python/sglang/srt/models/grok.py                     +10 -10
python/sglang/srt/models/internlm2.py                +10 -10
python/sglang/srt/models/llama.py                    +10 -10
python/sglang/srt/models/llama_classification.py      +5  -5
python/sglang/srt/models/llama_embedding.py           +4  -4
python/sglang/srt/models/llama_reward.py              +8  -8
python/sglang/srt/models/llava.py                    +11 -11
python/sglang/srt/models/llavavid.py                 +11 -11
python/sglang/srt/models/minicpm.py                  +10 -10
python/sglang/srt/models/minicpm3.py                 +12 -12
python/sglang/srt/models/mixtral.py                  +10 -10
python/sglang/srt/models/mixtral_quant.py            +10 -10
python/sglang/srt/models/olmoe.py                    +10 -10
python/sglang/srt/models/qwen.py                     +10 -10
python/sglang/srt/models/qwen2.py                    +11 -11
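
Every hunk in the diffs below applies the same mechanical change: the forward-pass argument previously declared as `input_metadata: InputMetadata` becomes `forward_batch: ForwardBatch`, the import switches to the new name, and every call site threads the renamed object through unchanged. The following is a minimal, self-contained sketch of that pattern; the `ExampleAttention` module and the stub `ForwardBatch` dataclass are illustrative stand-ins, not the actual sglang classes.

# Illustrative sketch only: a stub ForwardBatch and a toy attention module that
# mirror the signature change made throughout this commit.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ForwardBatch:  # stand-in for sglang.srt.model_executor.forward_batch_info.ForwardBatch
    batch_size: int


class ExampleAttention(nn.Module):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,  # was: input_metadata: InputMetadata
    ) -> torch.Tensor:
        # The batch object is only renamed; it is still passed straight through
        # to whatever attention backend consumes it.
        assert hidden_states.shape[0] >= forward_batch.batch_size
        return self.proj(hidden_states)


# Call sites change in lockstep: pass `forward_batch` where `input_metadata` was passed before.
layer = ExampleAttention()
out = layer(torch.arange(4), torch.randn(4, 16), ForwardBatch(batch_size=4))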
python/sglang/srt/models/deepseek_v2.py
@@ -46,7 +46,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_hip
 # ROCm: flashinfer available later
@@ -281,7 +281,7 @@ class DeepseekV2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
@@ -314,7 +314,7 @@ class DeepseekV2Attention(nn.Module):
         v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
             -1, self.num_local_heads * 256
         )
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, 256)[
             ..., : self.v_head_dim
         ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -433,7 +433,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -471,7 +471,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe
-        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -567,7 +567,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -579,7 +579,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -623,14 +623,14 @@ class DeepseekV2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -658,11 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/exaone.py
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class ExaoneGatedMLP(nn.Module):
@@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.out_proj(attn_output)
         return output
@@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -232,7 +232,7 @@ class ExaoneDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -270,7 +270,7 @@ class ExaoneModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -283,7 +283,7 @@ class ExaoneModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.transformer(
-            input_ids, positions, input_metadata, input_embeds
+            input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/gemma.py
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class GemmaMLP(nn.Module):
@@ -137,12 +137,12 @@ class GemmaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -192,7 +192,7 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -226,7 +226,7 @@ class GemmaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -243,7 +243,7 @@ class GemmaModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/gemma2.py
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -175,12 +175,12 @@ class Gemma2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
@@ -241,7 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -286,7 +286,7 @@ class Gemma2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -302,7 +302,7 @@ class Gemma2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )
     def get_attention_sliding_window_size(self):

python/sglang/srt/models/gpt_bigcode.py
@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class GPTBigCodeAttention(nn.Module):
@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.split(
@@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module):
             ],
             dim=-1,
         )
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output
@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(
-            hidden_states=hidden_states, input_metadata=input_metadata
+            hidden_states=hidden_states, forward_batch=forward_batch
         )
         # residual connection
         hidden_states = attn_output + residual
@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
@@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module):
         for i in range(len(self.h)):
             layer = self.h[i]
-            hidden_states = layer(hidden_states, input_metadata)
+            hidden_states = layer(hidden_states, forward_batch)
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/grok.py
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class Grok1MoE(nn.Module):
@@ -173,12 +173,12 @@ class Grok1Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # Self Attention
         hidden_states = (
@@ -227,7 +227,7 @@ class Grok1DecoderLayer(nn.Module):
             self.self_attn(
                 positions=positions,
                 hidden_states=self.pre_attn_norm(hidden_states),
-                input_metadata=input_metadata,
+                forward_batch=forward_batch,
             )
             )
             + hidden_states
@@ -268,7 +268,7 @@ class Grok1Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -278,7 +278,7 @@ class Grok1Model(nn.Module):
             hidden_states = input_embeds
         for i in range(len(self.layers)):
-            hidden_states = self.layers[i](positions, hidden_states, input_metadata)
+            hidden_states = self.layers[i](positions, hidden_states, forward_batch)
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
         return hidden_states
@@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/internlm2.py
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class InternLM2MLP(nn.Module):
@@ -137,12 +137,12 @@ class InternLM2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.wqkv(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.wo(attn_output)
         return output
@@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -194,7 +194,7 @@ class InternLMDecoderLayer(nn.Module):
         hidden_states = self.attention(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -229,7 +229,7 @@ class InternLM2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -242,7 +242,7 @@ class InternLM2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, input_metadata
+            input_ids, hidden_states, self.output.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/llama.py
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class LlamaMLP(nn.Module):
@@ -162,12 +162,12 @@ class LlamaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -270,7 +270,7 @@ class LlamaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -283,7 +283,7 @@ class LlamaModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -310,12 +310,12 @@ class LlamaForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )
     def get_hidden_dim(self, module_name):

python/sglang/srt/models/llama_classification.py
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         is_eos_token = input_ids == self.eos_token_id
         hidden_states = hidden_states[is_eos_token]
         scores = self.classification_head(hidden_states)
-        if scores.shape[0] != input_metadata.batch_size:
+        if scores.shape[0] != forward_batch.batch_size:
             print("Warning: the EOS tokens are missing in some sentences.")
             scores = torch.ones(
-                (input_metadata.batch_size, self.config.classification_out_size)
+                (forward_batch.batch_size, self.config.classification_out_size)
             ).to(input_ids.device)
         logits_output = LogitsProcessorOutput(

python/sglang/srt/models/llama_embedding.py
@@ -6,7 +6,7 @@ from transformers import LlamaConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.models.llama import LlamaModel
@@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
         assert (
             get_embedding
         ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.pooler(hidden_states, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.pooler(hidden_states, forward_batch)
     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None

python/sglang/srt/models/llama_reward.py
@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -51,13 +51,13 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> EmbeddingPoolerOutput:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         scores = self.score(hidden_states)
-        return self.pooler(scores, input_metadata)
+        return self.pooler(scores, forward_batch)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
@@ -102,19 +102,19 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
         assert (
             get_embedding
         ), "LlamaForSequenceClassification is only used for embedding"
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         logits = self.score(hidden_states)
         weights = self.weights(hidden_states)
-        pooled_logits = self.pooler(logits, input_metadata).embeddings
-        pooled_weights = self.pooler(weights, input_metadata).embeddings
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+        pooled_weights = self.pooler(weights, forward_batch).embeddings
         rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
             -1, self.num_labels // 2

python/sglang/srt/models/llava.py
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -130,12 +130,12 @@ class LlavaBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = input_metadata.image_inputs
+        image_inputs = forward_batch.image_inputs
-        if input_metadata.forward_mode.is_extend():
-            bs = input_metadata.batch_size
+        if forward_batch.forward_mode.is_extend():
+            bs = forward_batch.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -151,7 +151,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
             if need_vision.any():
@@ -348,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
                 image_features = new_image_features
                 # Fill in the placeholder for the image
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
+                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
@@ -379,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
                     pt += 1
             return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode.is_decode():
-            return self.language_model(input_ids, positions, input_metadata)
+        elif forward_batch.forward_mode.is_decode():
+            return self.language_model(input_ids, positions, forward_batch)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # Load clip vision model by cfg['mm_vision_tower']:

python/sglang/srt/models/llavavid.py
@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -108,11 +108,11 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = input_metadata.image_inputs
-        if input_metadata.forward_mode.is_extend():
-            bs = input_metadata.batch_size
+        image_inputs = forward_batch.image_inputs
+        if forward_batch.forward_mode.is_extend():
+            bs = forward_batch.batch_size
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -124,7 +124,7 @@ class LlavaVidForCausalLM(nn.Module):
                     max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)
-            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
             if need_vision.any():
@@ -169,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module):
                 image_features = new_image_features
                 # Fill in the placeholder for the image
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
+                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
@@ -200,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module):
                     pt += 1
             return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode.is_decode():
-            return self.language_model(input_ids, positions, input_metadata)
+        elif forward_batch.forward_mode.is_decode():
+            return self.language_model(input_ids, positions, forward_batch)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # Load clip vision model by cfg['mm_vision_tower']:

python/sglang/srt/models/minicpm.py
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class MiniCPMMLP(nn.Module):
@@ -148,7 +148,7 @@ class MiniCPMAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module):
         q, k = q.float(), k.float()
         q, k = self.rotary_emb(positions, q, k)
         q, k = q.to(orig_dtype), k.to(orig_dtype)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states * (
             self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states = self.norm(hidden_states)
@@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is not None:
             input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
         return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, input_metadata
+            input_ids, hidden_states, lm_head_weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/minicpm3.py
@@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_hip
 # ROCm: flashinfer available later
@@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
@@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module):
         v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
             -1, self.num_local_heads * 128
         )
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, 128)[
             ..., : self.v_head_dim
         ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe
-        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states * (
             self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states = self.norm(hidden_states)
@@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is not None:
             input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
         return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, input_metadata
+            input_ids, hidden_states, lm_head_weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/mixtral.py
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class MixtralMoE(nn.Module):
@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -270,7 +270,7 @@ class MixtralModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -281,7 +281,7 @@ class MixtralModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/mixtral_quant.py
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class MixtralMLP(nn.Module):
@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -303,7 +303,7 @@ class MixtralModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -314,7 +314,7 @@ class MixtralModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/olmoe.py
@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class OlmoeMoE(nn.Module):
@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/qwen.py
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class QWenMLP(nn.Module):
@@ -133,12 +133,12 @@ class QWenAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.c_proj(attn_output)
         return output
@@ -177,7 +177,7 @@ class QWenBlock(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -185,7 +185,7 @@ class QWenBlock(nn.Module):
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states
@@ -224,7 +224,7 @@ class QWenModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
         for i in range(len(self.h)):
@@ -232,7 +232,7 @@ class QWenModel(nn.Module):
             hidden_states = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ):
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

python/sglang/srt/models/qwen2.py
@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 Qwen2Config = None
@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, input_metadata
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
             )
         else:
-            return self.pooler(hidden_states, input_metadata)
+            return self.pooler(hidden_states, forward_batch)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
