sglang · Commits · 118b6af3

Commit 118b6af3 (unverified), authored Dec 01, 2024 by Yineng Zhang, committed by GitHub on Dec 01, 2024
feat: add should_use_tensor_core (#2179)
Parent: 9449a954

Showing 3 changed files with 65 additions and 19 deletions:
- python/sglang/srt/layers/attention/flashinfer_backend.py (+13, -15)
- python/sglang/srt/utils.py (+48, -0)
- scripts/deprecated/test_flashinfer.py (+4, -4)
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
         # Parse constants
-        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = get_bool_env_var(
-                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-            )
-        else:
-            if not _grouped_size_compiled_for_decode_kernels(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-            ):
-                self.decode_use_tensor_cores = True
-            else:
-                self.decode_use_tensor_cores = False
-
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
         self.max_context_len = model_runner.model_config.context_len
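Note for readers of this hunk: the environment override that previously lived in this constructor now sits inside should_use_tensor_core (see the utils.py hunk below), and it compares against the literal string "true" (case-insensitive) rather than going through get_bool_env_var. The following is an illustrative sketch only, not part of the commit, assuming the sglang package from this commit is importable:

# Sketch: forcing either decode path via the override read inside
# should_use_tensor_core (only the string "true", case-insensitive, enables it).
import os

import torch

from sglang.srt.utils import should_use_tensor_core

os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "true"
print(should_use_tensor_core(torch.half, num_attention_heads=32, num_kv_heads=8))  # True

os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "false"
print(should_use_tensor_core(torch.half, num_attention_heads=32, num_kv_heads=8))  # False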
python/sglang/srt/utils.py

@@ -1108,3 +1108,51 @@ def cuda_device_count_stateless() -> int:
     # This can be removed and simply replaced with torch.cuda.get_device_count
     # after https://github.com/pytorch/pytorch/pull/122815 is released.
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
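When no environment override is set and the legacy _grouped_size_compiled_for_decode_kernels helper cannot be imported (flashinfer newer than 0.1.6), the decision falls through to the dtype/GQA heuristic. A minimal sketch of what that final branch returns, assuming the two earlier branches are skipped:

# Sketch of the dtype/GQA fallback only; assumes no
# SGLANG_FLASHINFER_USE_TENSOR_CORE override and a flashinfer build
# without _grouped_size_compiled_for_decode_kernels.
import torch

from sglang.srt.utils import should_use_tensor_core

# FP8 KV cache: tensor cores are always used.
print(should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8))  # True

# FP16 KV cache, GQA group size 32 // 8 = 4: not greater than 4, so no tensor cores.
print(should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=8))  # False

# BF16 KV cache, GQA group size 32 // 4 = 8: group size > 4, tensor cores enabled.
print(should_use_tensor_core(torch.bfloat16, num_attention_heads=32, num_kv_heads=4))  # True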
scripts/deprecated/test_flashinfer.py

@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )
+from sglang.srt.utils import should_use_tensor_core
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(
 def init_flashinfer(num_attention_heads, num_kv_heads):
-    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
-        use_tensor_cores = True
-    else:
-        use_tensor_cores = False
+    use_tensor_cores = should_use_tensor_core(
+        torch.half, num_attention_heads, num_kv_heads
+    )
 
     workspace_buffer = torch.empty(
         128 * 1024 * 1024, dtype=torch.int8, device="cuda"
     )
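For context (not shown in the diff): the flag computed above is what flashinfer's decode wrapper consumes. A hedged sketch of how init_flashinfer would typically wire it up, assuming BatchDecodeWithPagedKVCacheWrapper accepts a use_tensor_cores keyword as in flashinfer 0.1.x; the wrapper construction is an assumption about the surrounding, unshown test code:

# Illustrative only: feed the computed flag into flashinfer's decode wrapper.
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.utils import should_use_tensor_core


def init_decode_wrapper(num_attention_heads, num_kv_heads):
    # Same decision the test script makes for an FP16 KV cache.
    use_tensor_cores = should_use_tensor_core(
        torch.half, num_attention_heads, num_kv_heads
    )
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    # Tensor-core decode selects a different kernel path inside flashinfer.
    return BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )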