change / sglang · Commits · 519e20cf
"vscode:/vscode.git/clone" did not exist on "968c52dd7bde8325e85fc53fa7f38491f1399656"
Commit 519e20cf (unverified), authored Jul 12, 2024 by Lianmin Zheng, committed by GitHub on Jul 12, 2024.

Code clean up: Remove deprecated prefill and move InputMetadata to infer_batch.py (#609)

Parent: d9a69029
Showing 3 changed files with 219 additions and 245 deletions.
python/sglang/srt/layers/radix_attention.py (+2, -5)
python/sglang/srt/managers/controller/infer_batch.py (+212, -3)
python/sglang/srt/managers/controller/model_runner.py (+5, -237)
python/sglang/srt/layers/radix_attention.py
@@ -8,6 +8,7 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
@@ -29,8 +30,6 @@ class RadixAttention(nn.Module):
         self.scaling = scaling
         self.layer_id = layer_id
 
-        from sglang.srt.managers.controller.model_runner import global_server_args_dict
-
         if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
@@ -141,9 +140,7 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.PREFILL:
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)
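
Why dropping the PREFILL branch is safe (an illustrative sketch, not part of the diff): a prefill is just an extend whose cached prefix is empty, so the EXTEND path already covers it. The tensors below are made up for the example.

import torch

# A "prefill" is an "extend" with an empty cached prefix: nothing is reused,
# so the extend kernel computes KV entries for the entire prompt.
seq_lens = torch.tensor([5, 3])                     # total tokens per request
prefix_lens = torch.zeros(2, dtype=seq_lens.dtype)  # no KV cached yet

extend_seq_lens = seq_lens - prefix_lens
assert torch.equal(extend_seq_lens, seq_lens)       # degenerates to a full prefill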
python/sglang/srt/managers/controller/infer_batch.py
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Store some global server args
+global_server_args_dict = {}
+
 
 class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
+    # Decode one token.
     DECODE = auto()
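
For orientation, here is a sketch of how the two surviving modes are used over a request's lifetime (illustrative only, not part of the diff): the prompt is processed once as an EXTEND batch, then each output token comes from a DECODE batch.

# Illustrative numbers: a 7-token prompt with 3 tokens already cached by a
# matching prefix (e.g., a shared system prompt), followed by 3 decode steps.
prompt_len, cached_prefix_len = 7, 3

steps = [("EXTEND", prompt_len - cached_prefix_len)]  # 4 new KV entries at once
steps += [("DECODE", 1)] * 3                          # one token per step

assert steps == [("EXTEND", 4), ("DECODE", 1), ("DECODE", 1), ("DECODE", 1)]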
@@ -66,6 +72,8 @@ class FINISH_ABORT(BaseFinishReason):
 class Req:
+    """Store all information of a request."""
+
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
@@ -74,7 +82,7 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
-        # For incremental decode
+        # For incremental decoding
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -93,9 +101,8 @@ class Req:
         self.sampling_params = None
         self.stream = False
-        self.tokenizer = None
 
         # Check finish
+        self.tokenizer = None
         self.finished_reason = None
 
         # Prefix info
@@ -252,6 +259,8 @@ class Req:
 @dataclass
 class Batch:
+    """Store all information of a batch."""
+
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
@@ -692,3 +701,203 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
     ] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     return probs_sort, probs_idx
+
+
+@dataclass
+class InputMetadata:
+    """Store all information of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    max_seq_len: int
+    req_pool_indices: torch.Tensor
+    start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    prefix_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    max_extend_len: int = 0
+
+    out_cache_loc: torch.Tensor = None
+    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_end: torch.Tensor = None
+
+    other_kv_index: torch.Tensor = None
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # For flashinfer
+    qo_indptr: torch.Tensor = None
+    kv_indptr: torch.Tensor = None
+    kv_indices: torch.Tensor = None
+    kv_last_page_len: torch.Tensor = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if self.forward_mode == ForwardMode.EXTEND:
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        self.kv_indices = torch.cat(
+            [
+                self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(self.batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+
+        if self.forward_mode == ForwardMode.EXTEND:
+            # extend part
+            self.qo_indptr = torch.zeros(
+                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+            )
+            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
+                self.qo_indptr,
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
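
The kv_indptr / kv_indices pair built above is a CSR-style index: request i's KV-cache slots are kv_indices[kv_indptr[i] : kv_indptr[i + 1]]. A minimal CPU sketch (illustrative, not part of the diff; the tiny req_to_token table is made up):

import torch

seq_lens = torch.tensor([3, 2], dtype=torch.int32)  # decode: full lengths are paged
kv_indptr = torch.zeros(3, dtype=torch.int32)       # batch_size + 1 offsets
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)       # -> [0, 3, 5]

# req_to_token maps (request row, position) -> slot in the global KV pool.
req_to_token = torch.tensor([[10, 11, 12, -1],
                             [20, 21, -1, -1]])
kv_indices = torch.cat([req_to_token[i, : seq_lens[i]] for i in range(2)])

assert kv_indices.tolist() == [10, 11, 12, 20, 21]
assert kv_indices[kv_indptr[0] : kv_indptr[1]].tolist() == [10, 11, 12]

The same hunk continues below.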
+    def init_extend_args(self):
+        self.extend_seq_lens = self.seq_lens - self.prefix_lens
+        self.extend_start_loc = torch.zeros_like(self.seq_lens)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+        self.max_extend_len = int(torch.max(self.extend_seq_lens))
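
init_extend_args above uses an exclusive cumulative sum so each request knows its offset into the flattened batch of newly computed tokens. A concrete run (illustrative, not part of the diff):

import torch

seq_lens = torch.tensor([5, 7, 4])
prefix_lens = torch.tensor([2, 0, 1])              # tokens already in the KV cache

extend_seq_lens = seq_lens - prefix_lens           # [3, 7, 3] tokens to compute
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)

assert extend_start_loc.tolist() == [0, 3, 10]     # per-request offsets
assert int(torch.max(extend_seq_lens)) == 7        # max_extend_len

The hunk concludes with the create() constructor below.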
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        tp_size,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        out_cache_cont_start=None,
+        out_cache_cont_end=None,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
+    ):
+        batch_size = len(req_pool_indices)
+        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+        total_num_tokens = int(torch.sum(seq_lens))
+        max_seq_len = int(torch.max(seq_lens))
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            other_kv_index = model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices[0], seq_lens[0] - 1
+            ].item()
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            other_kv_index = None
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            max_seq_len=max_seq_len,
+            req_pool_indices=req_pool_indices,
+            start_loc=start_loc,
+            seq_lens=seq_lens,
+            prefix_lens=prefix_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            out_cache_cont_start=out_cache_cont_start,
+            out_cache_cont_end=out_cache_cont_end,
+            other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+        )
+
+        if forward_mode == ForwardMode.EXTEND:
+            ret.init_extend_args()
+
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim,
+            )
+
+        return ret
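
The two position computations in create() differ by mode: DECODE places the single new token at index seq_len - 1, while EXTEND emits one position for every uncached prompt token. A CPU sketch with made-up lengths (illustrative, not part of the diff):

import numpy as np
import torch

seq_lens = torch.tensor([5, 3])
prefix_lens = torch.tensor([2, 0])
offsets = torch.tensor([0, 0])                     # position_ids_offsets

decode_positions = ((seq_lens - 1) + offsets).to(torch.int64)
assert decode_positions.tolist() == [4, 2]         # one new token per request

extend_positions = np.concatenate([
    np.arange(int(prefix_lens[i] + offsets[i]), int(seq_lens[i] + offsets[i]))
    for i in range(2)
])
assert extend_positions.tolist() == [2, 3, 4, 0, 1, 2]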
python/sglang/srt/managers/controller/model_runner.py
@@ -4,11 +4,9 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
-from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import Optional, Type
 
-import numpy as np
 import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_parallel
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -29,210 +27,6 @@ from sglang.srt.utils import (
 logger = logging.getLogger("srt.model_runner")
 
-# for server args in model endpoints
-global_server_args_dict = {}
-
-
-@dataclass
-class InputMetadata:
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-    (... remainder of the removed class elided; see note below ...)
 
 class ModelRunner:
     def __init__(

(Note: the elided body of the removed InputMetadata class is identical, line for line, to the version added to infer_batch.py above, with two differences: it had no docstring, and both mode checks in init_flashinfer_args read "if (self.forward_mode == ForwardMode.PREFILL or self.forward_mode == ForwardMode.EXTEND):" instead of testing EXTEND alone.)
@@ -245,6 +39,7 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
     ):
+        # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.gpu_id = gpu_id
@@ -256,7 +51,6 @@ class ModelRunner:
         monkey_patch_vllm_dummy_weight_loader()
 
         # Init torch distributed
-        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
@@ -287,11 +81,8 @@ class ModelRunner:
         )
 
         # Set some global args
-        global global_server_args_dict
-        global_server_args_dict = {
-            "disable_flashinfer": server_args.disable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
 
         # Load the model and create memory pool
         self.load_model()
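
The switch from rebinding to item assignment matters because global_server_args_dict now lives in infer_batch.py and is shared via import: mutating the imported dict is visible to every module that imported it, while rebinding the name (as the removed code did) would only change this module's local reference. A self-contained sketch using a throwaway module object (illustrative, not part of the diff):

import types

# Stand-in for infer_batch.py, the module that owns the shared dict.
infer_batch = types.ModuleType("infer_batch")
infer_batch.global_server_args_dict = {}

# Stand-in for "from ...infer_batch import global_server_args_dict" elsewhere.
imported_ref = infer_batch.global_server_args_dict

imported_ref["disable_flashinfer"] = True             # mutation: shared
assert infer_batch.global_server_args_dict["disable_flashinfer"] is True

imported_ref = {"disable_flashinfer": False}          # rebinding: local name only
assert infer_batch.global_server_args_dict["disable_flashinfer"] is True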
@@ -425,27 +216,6 @@ class ModelRunner:
         ) = None
         self.flashinfer_decode_wrapper = None
 
-    @torch.inference_mode()
-    def forward_prefill(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
     @torch.inference_mode()
     def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
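
Both surviving entry points keep the torch.inference_mode() decorator, which disables autograd recording for everything inside the call. A minimal sketch (illustrative, not part of the diff):

import torch

@torch.inference_mode()
def step(x):
    return x * 2

y = step(torch.ones(3, requires_grad=True))
assert not y.requires_grad  # no autograd graph is recorded under inference_mode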
@@ -523,8 +293,6 @@ class ModelRunner:
             return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
             return self.forward_extend(batch)
-        elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode}")