Commit 519e20cf (unverified)
Code clean up: Remove deprecated prefill & move InputMetadata to infer_batch.py (#609)

Authored Jul 12, 2024 by Lianmin Zheng; committed by GitHub on Jul 12, 2024.
Parent: d9a69029
Showing 3 changed files with 219 additions and 245 deletions.
python/sglang/srt/layers/radix_attention.py            +2    -5
python/sglang/srt/managers/controller/infer_batch.py   +212  -3
python/sglang/srt/managers/controller/model_runner.py  +5    -237
python/sglang/srt/layers/radix_attention.py  (+2, -5)

@@ -8,6 +8,7 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata

@@ -29,8 +30,6 @@ class RadixAttention(nn.Module):
         self.scaling = scaling
         self.layer_id = layer_id

-        from sglang.srt.managers.controller.model_runner import global_server_args_dict
-
         if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer

@@ -141,9 +140,7 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)

-        if input_metadata.forward_mode == ForwardMode.PREFILL:
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)
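With the deprecated PREFILL branch gone, RadixAttention.forward dispatches on exactly two modes: EXTEND, which also covers a fresh prefill (an extend whose cached prefix is empty), and DECODE. A self-contained sketch of the resulting control flow, with stand-in handlers in place of the real attention kernels:

from enum import IntEnum, auto

class ForwardMode(IntEnum):
    PREFILL = auto()  # deprecated: an EXTEND with an empty prefix covers it
    EXTEND = auto()   # compute attention for the uncached tail of a sequence
    DECODE = auto()   # compute attention for exactly one new token

# Stand-in handlers; the real methods call the Triton kernels or the
# FlashInfer wrappers selected in RadixAttention.__init__.
def extend_forward(q, k, v, input_metadata):
    return "extend attention"

def decode_forward(q, k, v, input_metadata):
    return "decode attention"

def forward(q, k, v, input_metadata):
    # Mirrors the simplified branching above: the PREFILL branch is gone.
    if input_metadata["forward_mode"] == ForwardMode.EXTEND:
        return extend_forward(q, k, v, input_metadata)
    elif input_metadata["forward_mode"] == ForwardMode.DECODE:
        return decode_forward(q, k, v, input_metadata)

print(forward(None, None, None, {"forward_mode": ForwardMode.EXTEND}))

Routing what used to be PREFILL through extend_forward is sound because extend attention over a zero-length cached prefix reduces to ordinary full-prompt attention.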
python/sglang/srt/managers/controller/infer_batch.py  (+212, -3)

@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+# Store some global server args
+global_server_args_dict = {}
+

 class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
+    # Decode one token.
     DECODE = auto()
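The new comments pin down the mode semantics and explain why PREFILL could be retired: a brand-new prompt is just an extend whose cached prefix is empty. A small numeric illustration; the lengths are invented, and the subtraction comes from init_extend_args later in this diff:

import torch

# Three requests in one batch: a fresh prompt, a partially cached prompt,
# and a request that is one token away from done.
seq_lens = torch.tensor([7, 7, 7])     # total tokens per request
prefix_lens = torch.tensor([0, 3, 6])  # tokens whose KV cache already exists

extend_seq_lens = seq_lens - prefix_lens
print(extend_seq_lens.tolist())  # [7, 4, 1]
# Request 0 has no cached prefix: the old PREFILL case, now handled by EXTEND.
# Request 2 computes a single new token, which is what DECODE does each step.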
@@ -66,6 +72,8 @@ class FINISH_ABORT(BaseFinishReason):
 class Req:
+    """Store all information of a request."""
+
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
@@ -74,7 +82,7 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

-        # For incremental decode
+        # For incremental decoding
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
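decoded_text, surr_offset, and read_offset exist because detokenization is not suffix-stable: many tokenizers render a token differently depending on its neighbors, so the detokenizer keeps a small surrounding window (see INIT_INCREMENTAL_DETOKENIZATION_OFFSET above) and re-decodes from it each step. A hypothetical sketch of that general pattern; the helper below is illustrative, not sglang's actual API:

def incremental_decode(tokenizer, all_ids, surr_offset, read_offset):
    """Decode only the new tokens, guarding against unstable suffixes.

    Returns (new_text, new_surr_offset, new_read_offset). `tokenizer` is any
    object with a decode(list_of_ids) -> str method; the offset policy here
    sketches the technique rather than sglang's exact algorithm.
    """
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    full_text = tokenizer.decode(all_ids[surr_offset:])
    if full_text.startswith(surr_text) and "\ufffd" not in full_text[-1:]:
        # The window decoded consistently: emit the suffix, slide both offsets.
        return full_text[len(surr_text):], read_offset, len(all_ids)
    # Otherwise hold output back until more tokens make the text stable.
    return "", surr_offset, read_offset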
@@ -93,9 +101,8 @@ class Req:
         self.sampling_params = None
         self.stream = False
+        self.tokenizer = None

         # Check finish
-        self.tokenizer = None
         self.finished_reason = None

         # Prefix info
@@ -252,6 +259,8 @@ class Req:
 @dataclass
 class Batch:
+    """Store all information of a batch."""
+
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
@@ -692,3 +701,203 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     ] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     return probs_sort, probs_idx
+
+
+@dataclass
+class InputMetadata:
+    """Store all information of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    max_seq_len: int
+    req_pool_indices: torch.Tensor
+    start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    prefix_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+
+    # for extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    max_extend_len: int = 0
+
+    out_cache_loc: torch.Tensor = None
+    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_end: torch.Tensor = None
+
+    other_kv_index: torch.Tensor = None
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # for flashinfer
+    qo_indptr: torch.Tensor = None
+    kv_indptr: torch.Tensor = None
+    kv_indices: torch.Tensor = None
+    kv_last_page_len: torch.Tensor = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if self.forward_mode == ForwardMode.EXTEND:
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        self.kv_indices = torch.cat(
+            [
+                self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(self.batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+
+        if self.forward_mode == ForwardMode.EXTEND:
+            # extend part
+            self.qo_indptr = torch.zeros(
+                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+            )
+            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
+                self.qo_indptr,
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+
+    def init_extend_args(self):
+        self.extend_seq_lens = self.seq_lens - self.prefix_lens
+        self.extend_start_loc = torch.zeros_like(self.seq_lens)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+        self.max_extend_len = int(torch.max(self.extend_seq_lens))
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        tp_size,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        out_cache_cont_start=None,
+        out_cache_cont_end=None,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
+    ):
+        batch_size = len(req_pool_indices)
+        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+        total_num_tokens = int(torch.sum(seq_lens))
+        max_seq_len = int(torch.max(seq_lens))
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            other_kv_index = model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices[0], seq_lens[0] - 1
+            ].item()
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            other_kv_index = None
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            max_seq_len=max_seq_len,
+            req_pool_indices=req_pool_indices,
+            start_loc=start_loc,
+            seq_lens=seq_lens,
+            prefix_lens=prefix_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            out_cache_cont_start=out_cache_cont_start,
+            out_cache_cont_end=out_cache_cont_end,
+            other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+        )
+
+        if forward_mode == ForwardMode.EXTEND:
+            ret.init_extend_args()
+
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim,
+            )
+
+        return ret
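init_flashinfer_args builds CSR-style indexing over the paged KV cache: kv_indptr is a zero-based prefix sum of the per-request KV lengths, and kv_indices concatenates each request's slot numbers out of the req_to_token map. A CPU-only walkthrough with toy values; the pool contents are invented for illustration, and the real code allocates on CUDA:

import torch

# Toy req_to_token map: 4 request slots x 8 token positions, each cell
# holding an index into the KV pool.
req_to_token = torch.arange(32).reshape(4, 8)

req_pool_indices = torch.tensor([2, 0])   # two requests in this batch
paged_kernel_lens = torch.tensor([3, 5])  # cached KV length per request
batch_size = 2

# Prefix sum: kv_indptr[i] .. kv_indptr[i+1] spans request i's KV entries.
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr.tolist())  # [0, 3, 8]

# Gather each request's KV-pool slots, then flatten into one index array.
kv_indices = torch.cat(
    [req_to_token[req_pool_indices[i], : paged_kernel_lens[i]]
     for i in range(batch_size)],
    dim=0,
)
print(kv_indices.tolist())  # [16, 17, 18, 0, 1, 2, 3, 4]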
python/sglang/srt/managers/controller/model_runner.py  (+5, -237)

@@ -4,11 +4,9 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
-from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import Optional, Type

-import numpy as np
 import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig

@@ -17,7 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -29,210 +27,6 @@ from sglang.srt.utils import (
 logger = logging.getLogger("srt.model_runner")

-# for server args in model endpoints
-global_server_args_dict = {}
-
-
-@dataclass
-class InputMetadata:
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool

(The remaining deleted lines are the body of the InputMetadata class, moved to infer_batch.py above, with two differences: this copy had no docstring, and both mode checks in init_flashinfer_args tested `self.forward_mode == ForwardMode.PREFILL or self.forward_mode == ForwardMode.EXTEND` where the new copy tests only `ForwardMode.EXTEND`.)

-        return ret
-
-
 class ModelRunner:
     def __init__(
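The one behavioral nuance in the move is that two-way mode test. Spelled out as runnable checks, with stand-in function names:

from enum import IntEnum, auto

class ForwardMode(IntEnum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()

def is_prefill_like_old(mode):
    # Check used by the removed model_runner.py copy.
    return mode == ForwardMode.PREFILL or mode == ForwardMode.EXTEND

def is_prefill_like_new(mode):
    # Check used by the copy added to infer_batch.py: PREFILL no longer
    # reaches this code path, so testing EXTEND alone is sufficient.
    return mode == ForwardMode.EXTEND

assert is_prefill_like_old(ForwardMode.EXTEND) == is_prefill_like_new(ForwardMode.EXTEND)
assert is_prefill_like_new(ForwardMode.PREFILL) is False  # old copy returned True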
@@ -245,6 +39,7 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
     ):
+        # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.gpu_id = gpu_id
@@ -256,7 +51,6 @@ class ModelRunner:
         monkey_patch_vllm_dummy_weight_loader()

         # Init torch distributed
-        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
@@ -287,11 +81,8 @@ class ModelRunner:
         )

         # Set some global args
-        global global_server_args_dict
-        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
+        global_server_args_dict = {
+            "disable_flashinfer": server_args.disable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }

         # Load the model and create memory pool
         self.load_model()
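After this commit the flag dictionary lives in infer_batch.py and is imported by both model_runner.py and radix_attention.py. A self-contained sketch of the shared-dict pattern; the function names are stand-ins, and the sketch deliberately mutates the dict in place with update() rather than rebinding the name, since rebinding a local would detach it from other modules that already imported it:

# Stand-in for infer_batch.py: one shared, module-level mapping.
global_server_args_dict = {}

def configure(server_args):
    # Stand-in for ModelRunner.__init__: publish the flags once at startup.
    global_server_args_dict.update(
        {
            "disable_flashinfer": server_args["disable_flashinfer"],
            "attention_reduce_in_fp32": server_args["attention_reduce_in_fp32"],
        }
    )

def pick_attention_backend():
    # Stand-in for RadixAttention.__init__: consult the shared flags later.
    if not global_server_args_dict.get("disable_flashinfer", False):
        return "flashinfer"
    return "triton"

configure({"disable_flashinfer": False, "attention_reduce_in_fp32": True})
print(pick_attention_backend())  # flashinfer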
@@ -425,27 +216,6 @@ class ModelRunner:
         ) = None
         self.flashinfer_decode_wrapper = None

-    @torch.inference_mode()
-    def forward_prefill(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
     @torch.inference_mode()
     def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
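The diff truncates forward_extend right after its InputMetadata.create( line, but that call visibly starts the same way as the removed forward_prefill. A hedged reconstruction under that assumption; the argument list is copied from the deleted method and only forward_mode should differ:

import torch

# Assumed body of the surviving prompt path, modeled on the removed
# forward_prefill above; sglang's actual forward_extend is not shown here.
@torch.inference_mode()
def forward_extend(self, batch):
    input_metadata = InputMetadata.create(
        self,
        forward_mode=ForwardMode.EXTEND,  # the one expected difference
        tp_size=self.tp_size,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        top_logprobs_nums=batch.top_logprobs_nums,
        return_logprob=batch.return_logprob,
        flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
        flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
        flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
    )
    return self.model.forward(batch.input_ids, input_metadata.positions, input_metadata)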
@@ -523,8 +293,6 @@ class ModelRunner:
             return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
             return self.forward_extend(batch)
-        elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode}")