Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
24e59f53
Unverified
Commit
24e59f53
authored
Mar 24, 2024
by
Liangsheng Yin
Committed by
GitHub
Mar 24, 2024
Browse files
`model_runner` simplify (#329)
parent
75235419
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
115 deletions
+51
-115
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+2
-6
python/sglang/srt/managers/router/model_runner.py
python/sglang/srt/managers/router/model_runner.py
+49
-109
No files found.
python/sglang/srt/managers/router/model_rpc.py
View file @
24e59f53
...
...
@@ -407,9 +407,7 @@ class ModelRpcServer:
prefill_logprobs
,
normalized_logprobs
,
last_logprobs
,
)
=
self
.
model_runner
.
forward
(
batch
,
ForwardMode
.
EXTEND
,
batch
.
return_logprob
)
)
=
self
.
model_runner
.
forward
(
batch
,
ForwardMode
.
EXTEND
)
if
prefill_logprobs
is
not
None
:
logprobs
=
prefill_logprobs
.
cpu
().
tolist
()
normalized_logprobs
=
normalized_logprobs
.
cpu
().
tolist
()
...
...
@@ -496,9 +494,7 @@ class ModelRpcServer:
# Forward
logits
,
(
_
,
_
,
last_logprobs
)
=
self
.
model_runner
.
forward
(
batch
,
ForwardMode
.
DECODE
,
batch
.
return_logprob
,
batch
,
ForwardMode
.
DECODE
)
next_token_ids
,
_
=
batch
.
sample
(
logits
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
...
...
python/sglang/srt/managers/router/model_runner.py
View file @
24e59f53
...
...
@@ -367,148 +367,88 @@ class ModelRunner:
)
@
torch
.
inference_mode
()
def
forward_prefill
(
self
,
input_ids
,
req_pool_indices
,
seq_lens
,
prefix_lens
,
position_ids_offsets
,
out_cache_loc
,
return_logprob
,
):
def
forward_prefill
(
self
,
batch
:
Batch
):
input_metadata
=
InputMetadata
.
create
(
self
,
forward_mode
=
ForwardMode
.
PREFILL
,
tp_size
=
self
.
tp_size
,
req_pool_indices
=
req_pool_indices
,
seq_lens
=
seq_lens
,
prefix_lens
=
prefix_lens
,
position_ids_offsets
=
position_ids_offsets
,
out_cache_loc
=
out_cache_loc
,
return_logprob
=
return_logprob
,
req_pool_indices
=
batch
.
req_pool_indices
,
seq_lens
=
batch
.
seq_lens
,
prefix_lens
=
batch
.
prefix_lens
,
position_ids_offsets
=
batch
.
position_ids_offsets
,
out_cache_loc
=
batch
.
out_cache_loc
,
return_logprob
=
batch
.
return_logprob
,
)
return
self
.
model
.
forward
(
batch
.
input_ids
,
input_metadata
.
positions
,
input_metadata
)
return
self
.
model
.
forward
(
input_ids
,
input_metadata
.
positions
,
input_metadata
)
@
torch
.
inference_mode
()
def
forward_extend
(
self
,
input_ids
,
req_pool_indices
,
seq_lens
,
prefix_lens
,
position_ids_offsets
,
out_cache_loc
,
return_logprob
,
):
def
forward_extend
(
self
,
batch
:
Batch
):
input_metadata
=
InputMetadata
.
create
(
self
,
forward_mode
=
ForwardMode
.
EXTEND
,
tp_size
=
self
.
tp_size
,
req_pool_indices
=
req_pool_indices
,
seq_lens
=
seq_lens
,
prefix_lens
=
prefix_lens
,
position_ids_offsets
=
position_ids_offsets
,
out_cache_loc
=
out_cache_loc
,
return_logprob
=
return_logprob
,
req_pool_indices
=
batch
.
req_pool_indices
,
seq_lens
=
batch
.
seq_lens
,
prefix_lens
=
batch
.
prefix_lens
,
position_ids_offsets
=
batch
.
position_ids_offsets
,
out_cache_loc
=
batch
.
out_cache_loc
,
return_logprob
=
batch
.
return_logprob
,
)
return
self
.
model
.
forward
(
batch
.
input_ids
,
input_metadata
.
positions
,
input_metadata
)
return
self
.
model
.
forward
(
input_ids
,
input_metadata
.
positions
,
input_metadata
)
@
torch
.
inference_mode
()
def
forward_decode
(
self
,
input_ids
,
req_pool_indices
,
seq_lens
,
prefix_lens
,
position_ids_offsets
,
out_cache_loc
,
out_cache_cont_start
,
out_cache_cont_end
,
return_logprob
,
):
def
forward_decode
(
self
,
batch
:
Batch
):
input_metadata
=
InputMetadata
.
create
(
self
,
forward_mode
=
ForwardMode
.
DECODE
,
tp_size
=
self
.
tp_size
,
req_pool_indices
=
req_pool_indices
,
seq_lens
=
seq_lens
,
prefix_lens
=
prefix_lens
,
position_ids_offsets
=
position_ids_offsets
,
out_cache_loc
=
out_cache_loc
,
out_cache_cont_start
=
out_cache_cont_start
,
out_cache_cont_end
=
out_cache_cont_end
,
return_logprob
=
return_logprob
,
req_pool_indices
=
batch
.
req_pool_indices
,
seq_lens
=
batch
.
seq_lens
,
prefix_lens
=
batch
.
prefix_lens
,
position_ids_offsets
=
batch
.
position_ids_offsets
,
out_cache_loc
=
batch
.
out_cache_loc
,
out_cache_cont_start
=
batch
.
out_cache_cont_start
,
out_cache_cont_end
=
batch
.
out_cache_cont_end
,
return_logprob
=
batch
.
return_logprob
,
)
return
self
.
model
.
forward
(
batch
.
input_ids
,
input_metadata
.
positions
,
input_metadata
)
return
self
.
model
.
forward
(
input_ids
,
input_metadata
.
positions
,
input_metadata
)
@
torch
.
inference_mode
()
def
forward_extend_multi_modal
(
self
,
input_ids
,
pixel_values
,
image_sizes
,
image_offsets
,
req_pool_indices
,
seq_lens
,
prefix_lens
,
position_ids_offsets
,
out_cache_loc
,
return_logprob
,
):
def
forward_extend_multi_modal
(
self
,
batch
:
Batch
):
input_metadata
=
InputMetadata
.
create
(
self
,
forward_mode
=
ForwardMode
.
EXTEND
,
tp_size
=
self
.
tp_size
,
req_pool_indices
=
req_pool_indices
,
seq_lens
=
seq_lens
,
prefix_lens
=
prefix_lens
,
position_ids_offsets
=
position_ids_offsets
,
out_cache_loc
=
out_cache_loc
,
return_logprob
=
return_logprob
,
req_pool_indices
=
batch
.
req_pool_indices
,
seq_lens
=
batch
.
seq_lens
,
prefix_lens
=
batch
.
prefix_lens
,
position_ids_offsets
=
batch
.
position_ids_offsets
,
out_cache_loc
=
batch
.
out_cache_loc
,
return_logprob
=
batch
.
return_logprob
,
)
return
self
.
model
.
forward
(
input_ids
,
batch
.
input_ids
,
input_metadata
.
positions
,
input_metadata
,
pixel_values
,
image_sizes
,
image_offsets
,
batch
.
pixel_values
,
batch
.
image_sizes
,
batch
.
image_offsets
,
)
def
forward
(
self
,
batch
:
Batch
,
forward_mode
:
ForwardMode
,
return_logprob
=
False
):
def
forward
(
self
,
batch
:
Batch
,
forward_mode
:
ForwardMode
):
if
self
.
is_multimodal_model
and
forward_mode
==
ForwardMode
.
EXTEND
:
kwargs
=
{
"input_ids"
:
batch
.
input_ids
,
"pixel_values"
:
batch
.
pixel_values
,
"image_sizes"
:
batch
.
image_sizes
,
"image_offsets"
:
batch
.
image_offsets
,
"req_pool_indices"
:
batch
.
req_pool_indices
,
"seq_lens"
:
batch
.
seq_lens
,
"prefix_lens"
:
batch
.
prefix_lens
,
"position_ids_offsets"
:
batch
.
position_ids_offsets
,
"out_cache_loc"
:
batch
.
out_cache_loc
,
"return_logprob"
:
return_logprob
,
}
return
self
.
forward_extend_multi_modal
(
**
kwargs
)
else
:
kwargs
=
{
"input_ids"
:
batch
.
input_ids
,
"req_pool_indices"
:
batch
.
req_pool_indices
,
"seq_lens"
:
batch
.
seq_lens
,
"prefix_lens"
:
batch
.
prefix_lens
,
"position_ids_offsets"
:
batch
.
position_ids_offsets
,
"out_cache_loc"
:
batch
.
out_cache_loc
,
"return_logprob"
:
return_logprob
,
}
if
forward_mode
==
ForwardMode
.
DECODE
:
kwargs
[
"out_cache_cont_start"
]
=
batch
.
out_cache_cont_start
kwargs
[
"out_cache_cont_end"
]
=
batch
.
out_cache_cont_end
return
self
.
forward_decode
(
**
kwargs
)
return
self
.
forward_extend_multi_modal
(
batch
)
elif
forward_mode
==
ForwardMode
.
DECODE
:
return
self
.
forward_decode
(
batch
)
elif
forward_mode
==
ForwardMode
.
EXTEND
:
return
self
.
forward_extend
(
**
kwargs
)
return
self
.
forward_extend
(
batch
)
elif
forward_mode
==
ForwardMode
.
PREFILL
:
return
self
.
forward_prefill
(
**
kwargs
)
return
self
.
forward_prefill
(
batch
)
else
:
raise
ValueError
(
f
"Invaid forward mode:
{
forward_mode
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment