zhaoyu6/sglang · Commit 6d0fa73e (unverified)

Simplify flashinfer utilities (#1704)
Authored Oct 17, 2024 by Lianmin Zheng; committed via GitHub on Oct 17, 2024.
Parent: 9e0dac1a

Showing 8 changed files with 391 additions and 337 deletions (+391 -337).
Changed files:
  python/sglang/srt/layers/attention/__init__.py                  +27   -5
  python/sglang/srt/layers/attention/double_sparsity_backend.py    +2   -2
  python/sglang/srt/layers/attention/flashinfer_backend.py       +352  -83
  python/sglang/srt/layers/attention/flashinfer_utils.py           +0 -237
  python/sglang/srt/layers/attention/triton_backend.py             +2   -2
  python/sglang/srt/managers/schedule_batch.py                     +4   -4
  python/sglang/srt/model_executor/forward_batch_info.py           +1   -4
  python/sglang/srt/model_executor/model_runner.py                 +3   -0
python/sglang/srt/layers/attention/__init__.py

 from abc import ABC, abstractmethod

+import torch
 from torch import nn

 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...
@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
...
@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for extend."""
         raise NotImplementedError()
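The abstract class above is mostly a small dispatcher: forward routes each call to forward_decode or forward_extend depending on the batch's forward mode, and the concrete backends touched in this commit (FlashInfer, Triton, double sparsity) override the two operators. The standalone toy below mirrors that shape with hypothetical names and no sglang imports; it is illustrative only and not part of the commit.

    from abc import ABC


    class ToyBatch:
        """Stand-in for ForwardBatch: only carries a decode flag (illustrative)."""

        def __init__(self, is_decode: bool):
            self._is_decode = is_decode

        def is_decode(self) -> bool:
            return self._is_decode


    class ToyAttentionBackend(ABC):
        def forward(self, q, k, v, layer, forward_batch: ToyBatch):
            # Same dispatch pattern as AttentionBackend.forward in the diff above.
            if forward_batch.is_decode():
                return self.forward_decode(q, k, v, layer, forward_batch)
            else:
                return self.forward_extend(q, k, v, layer, forward_batch)

        def forward_decode(self, q, k, v, layer, forward_batch):
            raise NotImplementedError()

        def forward_extend(self, q, k, v, layer, forward_batch):
            raise NotImplementedError()


    class EchoBackend(ToyAttentionBackend):
        def forward_decode(self, q, k, v, layer, forward_batch):
            return "decode"

        def forward_extend(self, q, k, v, layer, forward_batch):
            return "extend"


    print(EchoBackend().forward(None, None, None, None, ToyBatch(is_decode=True)))   # decode
    print(EchoBackend().forward(None, None, None, None, ToyBatch(is_decode=False)))  # extend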
python/sglang/srt/layers/attention/double_sparsity_backend.py

...
@@ -134,7 +134,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.forward_metadata = (
             self.cuda_graph_start_loc,
...
@@ -144,7 +144,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
...
python/sglang/srt/layers/attention/flashinfer_backend.py

...
@@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+from enum import Enum, auto
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
+import triton
+import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import (
-    WrapperDispatch,
-    update_flashinfer_indices,
-)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
...
@@ -34,13 +33,18 @@ if is_flashinfer_available():
     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.model_runner = model_runner

+        # Parse constants
         if not _grouped_size_compiled_for_decode_kernels(
             model_runner.model_config.num_attention_heads // model_runner.tp_size,
             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
...
@@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend):
             self.decode_use_tensor_cores = True
         else:
             self.decode_use_tensor_cores = False

-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device="cuda",
-        )
+        self.max_context_len = model_runner.model_config.context_len

         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.has_cross_attention
         ), "Sliding window and cross attention are not supported together"

-        self.num_wrappers = 1
-        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
         elif model_runner.has_cross_attention:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+        else:
+            self.num_wrappers = 1
+            self.dispatch_reason = None

+        # Allocate buffers
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device=model_runner.device,
+        )
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]

+        # Create wrappers
         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
             BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
...
@@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
             )

+        # Create indices updater
+        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
+        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(model_runner, self)
+
+        # Other metadata
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def _get_wrapper_idx(self, layer: nn.Module):
-        if self.num_wrappers == 1:
-            return 0
-
-        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            return layer.sliding_window_size == -1
-        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-            return layer.is_cross_attention
-
-        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
-            prefix_lens = None
-            use_ragged = False
-            extend_no_prefix = False
-            total_num_tokens = None
+            self.indices_updater_decode.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+            )
+            self.forward_metadata = (self.decode_wrappers,)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
...
@@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend):
             ):
                 use_ragged = True
-                total_num_tokens = torch.sum(forward_batch.seq_lens).item()

             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()

-            update_flashinfer_indices(
-                forward_batch.forward_mode,
-                self.model_runner,
+            self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 prefix_lens,
-                use_ragged=use_ragged,
+                use_ragged,
             )

-        self.forward_metadata = (
-            use_ragged,
-            extend_no_prefix,
-            total_num_tokens,
-            self.decode_wrappers,
-        )
+            self.forward_metadata = (
+                use_ragged,
+                extend_no_prefix,
+            )

     def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.model_runner.model_config.context_len,),
+        cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.max_context_len,),
             dtype=torch.int32,
             device="cuda",
         )
-        self.cuda_graph_kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-
-        # NOTE: the buffers are always in the form of list
-        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
-            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
-        ]
-        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
-            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
+            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):
...
@@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend):
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
                 )
             )

-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            decode_wrappers,
-        )
+        self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)

         self.cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = (False, False, None, decode_wrappers)
+        self.forward_metadata = (decode_wrappers,)

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            None,
-            self.cuda_graph_metadata[bs],
-        )
+        self.indices_updater_decode.update(
+            req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
+        )

     def get_cuda_graph_seq_len_fill_value(self):
...
@@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self._get_wrapper_idx(layer)
         ]
-        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
+        use_ragged, extend_no_prefix = self.forward_metadata

         if not use_ragged:
             if k is not None:
...
@@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         if k is not None:
             assert v is not None
...
@@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend):
             )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
+
+
+class FlashInferIndicesUpdaterDecode:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.decode_wrappers = attn_backend.decode_wrappers
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+        self.call_begin_forward(
+            decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
+        )
+
+    def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Sliding window attention
+                paged_kernel_lens = torch.minimum(
+                    # TODO: replace this with clamp
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size + 1),
+                )
+            else:
+                # Full attention
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        # TODO: optimize the blocking call on kv_indptr[-1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        wrapper.end_forward()
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+        )
+
+
+class FlashInferIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.qo_indptr = attn_backend.qo_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+    ):
+        if use_ragged:
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        self.call_begin_forward(
+            self.wrapper_ragged,
+            self.wrappers_paged[0],
+            req_pool_indices,
+            paged_kernel_lens,
+            seq_lens,
+            prefix_lens,
+            None,
+            self.kv_indptr[0],
+            self.qo_indptr[0],
+            use_ragged,
+        )
+
+    def update_sliding_window(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # window attention use paged only
+                paged_kernel_lens = torch.minimum(
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
+                )
+            else:
+                # full attention
+                paged_kernel_lens = seq_lens
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self,
+        wrapper_ragged,
+        wrapper_paged,
+        req_pool_indices,
+        paged_kernel_lens,
+        seq_lens,
+        prefix_lens,
+        kv_start_idx,
+        kv_indptr,
+        qo_indptr,
+        use_ragged,
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        qo_indptr = qo_indptr[: bs + 1]
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        # extend part
+        if use_ragged:
+            wrapper_ragged.end_forward()
+            wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        wrapper_paged.end_forward()
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    max_context_len: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
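For intuition, create_flashinfer_kv_indices_triton (added above, previously defined in flashinfer_utils.py) flattens, for each request in the batch, a slice of the req_to_token map into a contiguous kv_indices array whose per-request offsets come from kv_indptr, with an optional per-request kv_start_idx used for sliding-window attention. Below is a pure-PyTorch reference of the same gather, with a hypothetical helper name, CPU tensors, and a Python loop for clarity; it is a sketch for illustration, not code from the commit.

    import torch

    def build_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens,
                                   kv_indptr, kv_start_idx=None):
        # req_to_token: [max_batch, max_context_len] token slots per request.
        # For request i, copy req_to_token[req_pool_indices[i], start:start+len_i]
        # into kv_indices[kv_indptr[i]:kv_indptr[i+1]], where start is the optional
        # per-request offset (used for sliding-window attention).
        kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
        for i in range(len(req_pool_indices)):
            start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
            length = int(paged_kernel_lens[i])
            row = req_to_token[int(req_pool_indices[i])]
            kv_indices[int(kv_indptr[i]) : int(kv_indptr[i + 1])] = row[start : start + length]
        return kv_indices

    # Tiny usage example with made-up numbers:
    req_to_token = torch.arange(2 * 8, dtype=torch.int32).reshape(2, 8)
    req_pool_indices = torch.tensor([1, 0])
    paged_kernel_lens = torch.tensor([3, 2])
    kv_indptr = torch.zeros(3, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    print(build_kv_indices_reference(req_to_token, req_pool_indices,
                                     paged_kernel_lens, kv_indptr))
    # tensor([ 8,  9, 10,  0,  1], dtype=torch.int32)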
python/sglang/srt/layers/attention/flashinfer_utils.py
deleted 100644 → 0

from enum import Enum, auto

import torch
import triton
import triton.language as tl


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    max_context_len: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    req_to_token_ptr += req_pool_index * max_context_len
    kv_indices_ptr += kv_indices_offset

    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
    st_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = ld_offset < kv_end
        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
        ld_offset += BLOCK_SIZE
        st_offset += BLOCK_SIZE


class FlashinferUpdater:
    def __init__(
        self,
        forward_mode,
        model_runner,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        decode_wrappers=None,
        use_ragged=False,
    ):
        self.forward_mode = forward_mode
        self.model_runner = model_runner
        self.req_pool_indices = req_pool_indices
        self.seq_lens = seq_lens
        self.prefix_lens = prefix_lens
        self.use_ragged = use_ragged

        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            model_runner.tp_size
        )
        self.head_dim = model_runner.model_config.head_dim
        self.batch_size = len(req_pool_indices)

        self.decode_wrappers = (
            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
        )
        self.prefill_wrapper_ragged = (
            self.model_runner.attn_backend.prefill_wrapper_ragged
        )
        self.prefill_wrappers_paged = (
            self.model_runner.attn_backend.prefill_wrappers_paged
        )

        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

    def _update_decode_indices(self, decode_wrapper):
        assert not isinstance(decode_wrapper, list)
        decode_wrapper.end_forward()
        decode_wrapper.begin_forward(
            self.kv_indptr,
            self.kv_indices,
            self.kv_last_page_len,
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            data_type=self.model_runner.kv_cache_dtype,
            q_data_type=self.model_runner.dtype,
        )

    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
        assert not isinstance(paged_wrapper, list)
        assert not isinstance(ragged_wrapper, list)

        # extend part
        qo_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)

        if self.use_ragged:
            ragged_wrapper.end_forward()
            ragged_wrapper.begin_forward(
                qo_indptr,
                qo_indptr,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
            )

        # cached part
        paged_wrapper.end_forward()
        paged_wrapper.begin_forward(
            qo_indptr,
            self.kv_indptr,
            self.kv_indices,
            self.kv_last_page_len,
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
        )

    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
        if dispatch_reason is None:
            if self.use_ragged:
                paged_kernel_lens = self.prefix_lens
            else:
                paged_kernel_lens = self.seq_lens
            self.kv_start_idx = None
        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            if wrapper_id == 0:
                # window attention use paged only
                if self.forward_mode.is_decode():
                    paged_kernel_lens = torch.minimum(
                        self.seq_lens,
                        torch.tensor(self.model_runner.sliding_window_size + 1),
                    )
                else:
                    paged_kernel_lens = torch.minimum(
                        self.seq_lens,
                        torch.tensor(self.model_runner.sliding_window_size)
                        + self.seq_lens
                        - self.prefix_lens,
                    )
            else:
                # full attention
                paged_kernel_lens = self.seq_lens
            self.kv_start_idx = self.seq_lens - paged_kernel_lens

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        self.kv_indices = torch.empty(
            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
        )

        create_flashinfer_kv_indices_triton[(self.batch_size,)](
            self.model_runner.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            paged_kernel_lens,
            self.kv_indptr,
            self.kv_start_idx,
            self.kv_indices,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
        )

    def _update_indicess_single_wrapper(self):
        self._get_indices()

        if self.forward_mode.is_decode():
            self._update_decode_indices(self.decode_wrappers[0])
        else:
            self._update_extend_indices(
                self.prefill_wrapper_ragged,
                self.prefill_wrappers_paged[0],
            )

    def _update_indices_cross_attention(self):
        pass

    def _update_indices_sliding_window(self):
        assert self.use_ragged is False
        for wrapper_id in range(2):
            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
            if self.forward_mode.is_decode():
                self._update_decode_indices(self.decode_wrappers[wrapper_id])
            else:
                self._update_extend_indices(
                    None,
                    self.prefill_wrappers_paged[wrapper_id],
                )


def update_flashinfer_indices(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    decode_wrappers=None,
    use_ragged=False,
):
    updater = FlashinferUpdater(
        forward_mode,
        model_runner,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        decode_wrappers,
        use_ragged,
    )

    dispatch_reason = model_runner.attn_backend.dispatch_reason

    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        updater._update_indices_sliding_window()
    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        updater._update_indices_cross_attention()
    else:
        assert model_runner.attn_backend.num_wrappers == 1
        updater._update_indicess_single_wrapper()
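One practical difference between the deleted FlashinferUpdater above and the new updater classes in flashinfer_backend.py: the old code allocated fresh kv_indptr, qo_indptr, and kv_last_page_len tensors on every call, while the new backend allocates max_bs-sized buffers once in __init__ and the updaters fill [: bs + 1] views of them per batch (only kv_indices is still allocated per call). A minimal sketch of that reuse pattern, with illustrative names and CPU tensors, assuming nothing beyond what the diff shows:

    import torch

    # Buffer allocated once for the largest possible batch size.
    max_bs = 8
    kv_indptr_buf = torch.zeros((max_bs + 1,), dtype=torch.int32)

    def fill_indptr(paged_kernel_lens: torch.Tensor) -> torch.Tensor:
        bs = paged_kernel_lens.numel()
        kv_indptr = kv_indptr_buf[: bs + 1]      # view into the preallocated buffer
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        return kv_indptr

    print(fill_indptr(torch.tensor([3, 2, 4])))  # tensor([0, 3, 5, 9], dtype=torch.int32)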
python/sglang/srt/layers/attention/triton_backend.py

...
@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.forward_metadata = (
             self.cuda_graph_start_loc,
...
@@ -91,7 +91,7 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
...
python/sglang/srt/managers/schedule_batch.py

...
@@ -744,7 +744,6 @@ class ScheduleBatch:
         self.forward_mode = ForwardMode.DECODE
         self.input_ids = self.output_ids
-        self.seq_lens.add_(1)
         self.output_ids = None
         if self.sampling_info.penalizer_orchestrator:
             self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
...
@@ -755,9 +754,10 @@ class ScheduleBatch:
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)

     def filter_batch(
         self,
...
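The schedule_batch.py change reorders the decode-step bookkeeping: instead of incrementing seq_lens first and writing the new cache slot at seq_lens - 1, the slot is written at the current seq_lens and the increment happens afterwards. A tiny sketch with made-up sizes (illustrative only) showing both orders land the value in the same cell:

    import torch

    req_to_token = torch.zeros((1, 8), dtype=torch.int64)
    req_pool_indices = torch.tensor([0])
    seq_lens = torch.tensor([3])
    out_cache_loc = torch.tensor([42])

    # Old order: increment first, then write at seq_lens - 1.
    a = req_to_token.clone()
    sl = seq_lens.clone()
    sl.add_(1)
    a[req_pool_indices, sl - 1] = out_cache_loc

    # New order: write at seq_lens, then increment.
    b = req_to_token.clone()
    sl2 = seq_lens.clone()
    b[req_pool_indices, sl2] = out_cache_loc
    sl2.add_(1)

    assert torch.equal(a, b) and torch.equal(sl, sl2)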
python/sglang/srt/model_executor/forward_batch_info.py

...
@@ -134,9 +134,7 @@ class ForwardBatch:
         )

         # Init position information
-        if ret.forward_mode.is_decode():
-            ret.positions = (ret.seq_lens - 1).to(torch.int64)
-        else:
+        if not ret.forward_mode.is_decode():
             ret.positions = torch.tensor(
                 np.concatenate(
                     [
...
@@ -164,7 +162,6 @@ class ForwardBatch:
         ret.req_to_token_pool = model_runner.req_to_token_pool
         ret.token_to_kv_pool = model_runner.token_to_kv_pool
         ret.attn_backend = model_runner.attn_backend
-        model_runner.attn_backend.init_forward_metadata(ret)

         # Init lora information
         if model_runner.server_args.lora_paths is not None:
...
python/sglang/srt/model_executor/model_runner.py

...
@@ -551,11 +551,14 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(forward_batch)

+        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+        self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

     def forward_extend(self, forward_batch: ForwardBatch):
+        self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
             return self.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
...