change / sglang · Commits

Commit 3a6e8b6d (unverified)
[Minor] move triton attention kernels into a separate folder (#1379)

Authored Sep 10, 2024 by Lianmin Zheng; committed by GitHub on Sep 10, 2024
Parent: fbb4754c

Showing 13 changed files with 24 additions and 15 deletions (+24 -15)
Changed files:

  python/sglang/bench_latency.py                                   +1  -1
  python/sglang/srt/configs/model_config.py                        +0  -0  (moved)
  python/sglang/srt/constrained/__init__.py                        +2  -0
  python/sglang/srt/conversation.py                                +1  -1
  python/sglang/srt/hf_transformers_utils.py                       +1  -3
  python/sglang/srt/layers/radix_attention.py                      +11 -4
  python/sglang/srt/layers/triton_attention/decode_attention.py    +0  -0  (moved)
  python/sglang/srt/layers/triton_attention/extend_attention.py    +1  -1  (moved)
  python/sglang/srt/layers/triton_attention/prefill_attention.py   +0  -0  (moved)
  python/sglang/srt/managers/tp_worker.py                          +1  -1
  python/sglang/srt/model_executor/forward_batch_info.py           +1  -1
  python/sglang/srt/model_executor/model_runner.py                 +1  -2
  scripts/deprecated/test_flashinfer.py                            +4  -1
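For downstream code that imported the relocated modules directly, this commit is purely an import-path update. A minimal before/after sketch, using only symbol names that appear in the hunks below (assumes sglang is installed at this commit):

# Old import paths (before this commit):
# from sglang.srt.layers.decode_attention import decode_attention_fwd
# from sglang.srt.layers.extend_attention import extend_attention_fwd
# from sglang.srt.layers.prefill_attention import context_attention_fwd
# from sglang.srt.model_config import ModelConfig

# New import paths (after this commit):
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
from sglang.srt.configs.model_config import ModelConfig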
python/sglang/bench_latency.py

@@ -57,9 +57,9 @@ import pandas as pd
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
python/sglang/srt/model_config.py → python/sglang/srt/configs/model_config.py

File moved without content changes.
python/sglang/srt/constrained/__init__.py

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""For constrained decoding."""
+
 import json
 from typing import Dict, Optional, Union
python/sglang/srt/conversation.py

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Conversation templates."""
+"""Conversation chat templates."""
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
python/sglang/srt/hf_transformers_utils.py

@@ -16,11 +16,9 @@ limitations under the License.
 """Utilities for Huggingface Transformers."""
 
 import contextlib
-import functools
-import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
+from typing import Dict, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
python/sglang/srt/layers/radix_attention.py

@@ -22,13 +22,20 @@ from flashinfer.cascade import merge_state
 from torch import nn
 
 from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
+from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict
 
 
 class RadixAttention(nn.Module):
+    """
+    The attention layer implementation.
+    Now it has two backends: FlashInfer and Triton.
+    FlashInfer is faster and Triton is easier to customize.
+    It supports two operators: extend (i.e. prefill with cached prefix) and decode.
+    """
+
     def __init__(
         self,
         num_heads: int,

@@ -49,8 +56,10 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
+        # Choose backend
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
             and self.qk_head_dim == self.v_head_dim

@@ -61,8 +70,6 @@ class RadixAttention(nn.Module):
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
 
-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         if self.qk_head_dim != self.v_head_dim:
             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
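As the hunks above show, the constructor picks a backend once and binds extend_forward / decode_forward to that backend's methods, so the forward path carries no per-call branching. A self-contained sketch of this dispatch pattern (illustrative only; the flashinfer method names are assumptions, and this is not the full RadixAttention implementation):

class AttentionDispatch:
    """Illustrative only: binds the forward methods to one backend up front."""

    def __init__(self, use_flashinfer: bool):
        # Mirrors the "# Choose backend" block in radix_attention.py.
        if use_flashinfer:
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

    def extend_forward_triton(self, q):
        return f"triton extend on {q}"

    def decode_forward_triton(self, q):
        return f"triton decode on {q}"

    def extend_forward_flashinfer(self, q):
        return f"flashinfer extend on {q}"

    def decode_forward_flashinfer(self, q):
        return f"flashinfer decode on {q}"


# Callers invoke extend_forward/decode_forward without knowing which backend won.
layer = AttentionDispatch(use_flashinfer=False)
print(layer.extend_forward("q"))  # -> "triton extend on q"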
python/sglang/srt/layers/decode_attention.py → python/sglang/srt/layers/triton_attention/decode_attention.py

File moved without content changes.
python/sglang/srt/layers/extend_attention.py → python/sglang/srt/layers/triton_attention/extend_attention.py

@@ -22,7 +22,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.prefill_attention import context_attention_fwd
+from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
python/sglang/srt/layers/prefill_attention.py → python/sglang/srt/layers/triton_attention/prefill_attention.py

File moved without content changes.
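Taken together, the three moves above collect the Triton kernels under one subpackage. The resulting layout, reconstructed from the renames and the import hunks (whether an explicit __init__.py is added is not shown in this diff):

# python/sglang/srt/layers/triton_attention/
#     decode_attention.py    # decode_attention_fwd
#     extend_attention.py    # extend_attention_fwd, redundant_attention
#     prefill_attention.py   # context_attention_fwd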
python/sglang/srt/managers/tp_worker.py

@@ -29,6 +29,7 @@ import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer

@@ -52,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
python/sglang/srt/model_executor/forward_batch_info.py

@@ -15,7 +15,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""ModelRunner runs the forward passes of the models."""
+"""Meta data for a forward pass."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List
python/sglang/srt/model_executor/model_runner.py

@@ -18,7 +18,6 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
-import json
 import logging
 import pkgutil
 from functools import lru_cache

@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict

@@ -53,7 +53,6 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
scripts/deprecated/test_flashinfer.py

@@ -6,8 +6,11 @@ from flashinfer import (
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
-from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import (
+    extend_attention_fwd,
+    redundant_attention,
+)
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None