change / sglang / Commits / 84967019

Commit 84967019, authored Dec 22, 2024 by Lianmin Zheng
parent 7d672d27

[Misc] Fix metrics, weight update lock, request logging (#2543)

Changes: 11 changed files with 412 additions and 315 deletions (+412, -315)
docs/references/production_metrics.md                       +110  -170
python/sglang/srt/aio_rwlock.py                               +94    -0
python/sglang/srt/configs/model_config.py                      +4    -0
python/sglang/srt/layers/attention/flashinfer_backend.py      +49    -5
python/sglang/srt/managers/schedule_batch.py                  +16   -10
python/sglang/srt/managers/scheduler.py                        +2    -2
python/sglang/srt/managers/tokenizer_manager.py               +86   -76
python/sglang/srt/mem_cache/memory_pool.py                    +15    -8
python/sglang/srt/server.py                                    +1    -0
python/sglang/srt/utils.py                                    +33   -44
test/srt/test_metrics.py                                       +2    -0
docs/references/production_metrics.md

(This diff is collapsed in the page view and is not reproduced here.)
python/sglang/srt/aio_rwlock.py  (new file, mode 100644)

import asyncio


class RWLock:
    """
    A Read-Write Lock for asyncio:
    - Multiple readers can hold the lock in parallel if no writer holds it.
    - A writer has exclusive access.
    """

    def __init__(self):
        self._readers = 0  # How many readers currently hold the lock
        self._writer_active = False
        self._lock = asyncio.Lock()  # Internal mutex to protect state

        # Conditions associated with _lock:
        self._readers_ok = asyncio.Condition(self._lock)  # Notify blocked readers
        self._writers_ok = asyncio.Condition(self._lock)  # Notify blocked writers

        # Expose two async context-manager helpers:
        self.reader_lock = self._ReaderLock(self)
        self.writer_lock = self._WriterLock(self)

    async def _acquire_reader(self):
        """
        Wait until there is no active writer.
        Then increment the count of active readers.
        """
        async with self._lock:
            # If a writer is active, wait until it's done.
            while self._writer_active:
                await self._readers_ok.wait()
            self._readers += 1

    async def _release_reader(self):
        """
        Decrement the count of active readers.
        If this was the last active reader, wake up a possible waiting writer.
        """
        async with self._lock:
            self._readers -= 1
            # If no more readers, a writer could proceed.
            if self._readers == 0:
                self._writers_ok.notify()

    async def _acquire_writer(self):
        """
        Wait until there is no active writer and no active readers.
        Then mark a writer as active.
        """
        async with self._lock:
            while self._writer_active or self._readers > 0:
                await self._writers_ok.wait()
            self._writer_active = True

    async def _release_writer(self):
        """
        Mark the writer as done and notify readers and writers.
        """
        async with self._lock:
            self._writer_active = False
            # Allow any waiting readers to proceed:
            self._readers_ok.notify_all()
            # Allow next waiting writer to proceed:
            self._writers_ok.notify()

    class _ReaderLock:
        """
        A simple async context manager that acquires a reader lock
        on entering and releases it on exit.
        """

        def __init__(self, parent: "RWLock"):
            self._parent = parent

        async def __aenter__(self):
            await self._parent._acquire_reader()

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            await self._parent._release_reader()

    class _WriterLock:
        """
        A simple async context manager that acquires a writer lock
        on entering and releases it on exit.
        """

        def __init__(self, parent: "RWLock"):
            self._parent = parent

        async def __aenter__(self):
            await self._parent._acquire_writer()

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            await self._parent._release_writer()
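Note (not part of the commit): reader_lock and writer_lock are async context managers, so request handlers can share the lock while a weight update takes exclusive ownership. A minimal usage sketch:

    import asyncio

    from sglang.srt.aio_rwlock import RWLock


    async def main():
        rwlock = RWLock()

        async def handle_request(i: int):
            # Many readers may hold the lock at the same time.
            async with rwlock.reader_lock:
                await asyncio.sleep(0.1)
                print(f"request {i} finished under reader_lock")

        async def update_weights():
            # The writer waits for all readers to release, then runs exclusively.
            async with rwlock.writer_lock:
                print("weight update running with exclusive access")

        await asyncio.gather(*(handle_request(i) for i in range(3)), update_weights())


    asyncio.run(main())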
python/sglang/srt/configs/model_config.py

@@ -124,8 +124,12 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

+        # Verify quantization
+        self._verify_quantization()
+
+        # Multimodal attrs
+        self.image_token_id = getattr(self.hf_config, "image_token_id", None)

     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -18,11 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_flashinfer_available,
-    should_use_tensor_core,
-)
+from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention

@@ -731,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
             mask=mask,
         )
         tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
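For illustration only (not from the diff), here is how the dtype/GQA fallback in should_use_tensor_core resolves, assuming SGLANG_FLASHINFER_USE_TENSOR_CORE is unset and the private flashinfer helper is unavailable:

    import torch

    from sglang.srt.layers.attention.flashinfer_backend import should_use_tensor_core

    # bf16 KV cache, 32 query heads / 8 KV heads -> GQA group size 4, not > 4
    print(should_use_tensor_core(torch.bfloat16, num_attention_heads=32, num_kv_heads=8))    # False
    # bf16 KV cache, 32 query heads / 4 KV heads -> GQA group size 8 > 4
    print(should_use_tensor_core(torch.bfloat16, num_attention_heads=32, num_kv_heads=4))    # True
    # FP8 KV cache always prefers the tensor-core path
    print(should_use_tensor_core(torch.float8_e5m2, num_attention_heads=32, num_kv_heads=8)) # True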
python/sglang/srt/managers/schedule_batch.py

@@ -479,8 +479,22 @@ class Req:
         return True

+    def reset_for_retract(self):
+        self.prefix_indices = []
+        self.last_node = None
+        self.extend_input_len = 0
+        self.is_retracted = True
+
+        # For incremental logprobs
+        # TODO: Fix the `logprob_start_len`
+        self.last_update_decode_tokens = 0
+        self.logprob_start_len = 10**9
+
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+        return (
+            f"rid(n={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+        )


 bid = 0

@@ -894,15 +908,7 @@ class ScheduleBatch:
             )
             residual_size = max(0, residual_size)
             self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

-            req.prefix_indices = []
-            req.last_node = None
-            req.extend_input_len = 0
-            req.is_retracted = True
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
+            req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)
python/sglang/srt/managers/scheduler.py

@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import List, Optional
+from typing import Callable, Dict, List, Optional, Tuple

 import psutil
 import setproctitle

@@ -260,7 +260,7 @@ class Scheduler:
         self.current_stream = torch.get_device_module(self.device).current_stream()

         # Session info
-        self.sessions = {}
+        self.sessions: Dict[str, Session] = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
python/sglang/srt/managers/tokenizer_manager.py

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import uvloop

@@ -30,6 +30,7 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

+from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (

@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket, kill_process_tree
+from sglang.srt.utils import (
+    dataclass_to_string_truncated,
+    get_zmq_socket,
+    kill_process_tree,
+)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -82,6 +87,9 @@ class ReqState:
     created_time: float
     first_token_time: Optional[float] = None

+    # For streaming output
+    last_output_offset: int = 0
+

 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""

@@ -120,6 +128,7 @@ class TokenizerManager:
         self.is_generation = self.model_config.is_generation
         self.context_len = self.model_config.context_len
+        self.image_token_id = self.model_config.image_token_id

         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()

@@ -152,9 +161,12 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

-        # For update model weights
-        self.model_update_lock = asyncio.Lock()
-        self.model_update_result = None
+        # The event to notify the weight sync is finished.
+        self.model_update_lock = RWLock()
+        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+            None
+        )
+        self.asyncio_tasks = set()

         # For session info
         self.session_futures = {}  # session_id -> asyncio event

@@ -181,9 +193,6 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()

-        while self.model_update_lock.locked():
-            await asyncio.sleep(0.001)
-
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
                 "This model does not appear to be an embedding model by default. "

@@ -191,17 +200,24 @@ class TokenizerManager:
             )

         obj.normalize_batch_and_arguments()
-        is_single = obj.is_single
-        if is_single:
-            tokenized_obj = await self._tokenize_one_request(obj)
-            self.send_to_scheduler.send_pyobj(tokenized_obj)
-            async for response in self._wait_one_response(obj, request, created_time):
-                yield response
-        else:
-            async for response in self._handle_batch_request(obj, request, created_time):
-                yield response
+
+        if self.server_args.log_requests:
+            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+        async with self.model_update_lock.reader_lock:
+            is_single = obj.is_single
+            if is_single:
+                tokenized_obj = await self._tokenize_one_request(obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                async for response in self._wait_one_response(obj, request, created_time):
+                    yield response
+            else:
+                async for response in self._handle_batch_request(obj, request, created_time):
+                    yield response

     async def _tokenize_one_request(
         self,

@@ -215,7 +231,7 @@ class TokenizerManager:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
                     "input_embeds is provided while disable_radix_cache is False. "
-                    "Please add `--disable-radix-cach` when you launch the server "
+                    "Please add `--disable-radix-cache` when you launch the server "
                     "if you want to use input_embeds as inputs."
                 )
             input_embeds = obj.input_embeds

@@ -301,8 +317,8 @@ class TokenizerManager:
                     state.out_list = []
                     if state.finished:
                         if self.server_args.log_requests:
-                            # Log requests
-                            logger.info(f"in={obj}, out={out}")
+                            msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                            logger.info(msg)
                         del self.rid_to_state[obj.rid]
                         yield out
                         break

@@ -423,55 +439,52 @@ class TokenizerManager:
         self,
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
+        logger.info("Start update_weights. Load format=%s", obj.load_format)

-        if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                # wait for the previous generation requests to finish
-                for i in range(3):
-                    while len(self.rid_to_state) > 0:
-                        await asyncio.sleep(0.001)
-                    # FIXME: We add some sleep here to avoid some race conditions.
-                    # We can use a read-write lock as a better fix.
-                    await asyncio.sleep(0.01)
-                self.send_to_scheduler.send_pyobj(obj)
-                self.model_update_result = asyncio.Future()
-
-                if self.server_args.dp_size == 1:
-                    result = await self.model_update_result
-                    if result.success:
-                        self.server_args.model_path = obj.model_path
-                        self.server_args.load_format = obj.load_format
-                        self.model_path = obj.model_path
-                    return result.success, result.message
-                else:  # self.server_args.dp_size > 1
-                    self.model_update_tmp = []
-                    result = await self.model_update_result
-
-                    all_success = all([r.success for r in result])
-                    if all_success is True:
-                        self.server_args.model_path = obj.model_path
-                        self.server_args.load_format = obj.load_format
-                        self.model_path = obj.model_path
-                    all_message = [r.message for r in result]
-                    all_message = " | ".join(all_message)
-                    return all_success, all_message
-        else:
-            return False, "Another update is in progress. Please try again later."
+        if True:
+            # Hold the lock if it is not async. This means that weight sync
+            # cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)
+
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str, int]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message

     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> bool:
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

         self.send_to_scheduler.send_pyobj(obj)

@@ -487,25 +500,22 @@
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

-        if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                self.send_to_scheduler.send_pyobj(obj)
-                self.parameter_update_result = asyncio.Future()
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be for update weights from distributed"
-                result = await self.parameter_update_result
-                return result.success, result.message
-        else:
-            logger.error("Another parameter update is in progress in tokenizer manager")
-            return (
-                False,
-                "Another parameter update is in progress. Please try again later.",
-            )
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            self.send_to_scheduler.send_pyobj(obj)
+            self.parameter_update_result: Awaitable[
+                UpdateWeightsFromDistributedReqOutput
+            ] = asyncio.Future()
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be for update weights from distributed"
+            result = await self.parameter_update_result
+            return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None

@@ -564,11 +574,11 @@ class TokenizerManager:
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        loop.create_task(self.handle_loop())
+        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

         signal_handler = SignalHandler(self)
         loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        loop.create_task(self.sigterm_watchdog())
+        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
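One detail worth noting (my commentary, not part of the commit): storing the created tasks in self.asyncio_tasks keeps strong references to them. The asyncio event loop only holds weak references to running tasks, so a task created with create_task() and never referenced elsewhere can be garbage-collected before it completes; keeping tasks in a set is the standard way to avoid that. A minimal sketch of the pattern, illustration only:

    import asyncio

    background_tasks = set()

    async def worker():
        await asyncio.sleep(0.1)

    async def main():
        task = asyncio.create_task(worker())
        background_tasks.add(task)                          # keep a strong reference
        task.add_done_callback(background_tasks.discard)    # drop it once finished
        await asyncio.sleep(0.2)

    asyncio.run(main())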
python/sglang/srt/mem_cache/memory_pool.py

@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
     ):
         super().__init__(size, dtype, device)
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()

     def _create_buffers(self):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]

     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer

     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)

@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):

 class MLATokenToKVPool(BaseTokenToKVPool):
     def __init__(
         self,
         size: int,

@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):

 class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
     def __init__(
         self,
         size: int,
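A side note (my reading of the change, not stated in the commit): because _create_buffers now reads size, head_num, head_dim, layer_num, and device from self rather than from the constructor arguments, the KV-cache buffers can be torn down and rebuilt later. A hypothetical sketch of that pattern, with pool standing in for an MHATokenToKVPool instance:

    import torch

    # Hypothetical sketch: release the KV cache, do something memory-heavy,
    # then re-allocate buffers with the same sizes stored on the pool object.
    def rebuild_kv_cache(pool):
        pool._clear_buffers()      # drop references to k_buffer / v_buffer
        torch.cuda.empty_cache()   # let the allocator return the freed memory
        # ... e.g. reload model weights here ...
        pool._create_buffers()     # re-create buffers from self.size / self.head_num / ...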
python/sglang/srt/server.py

@@ -311,6 +311,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
+        logger.error(f"Error: {e}")
         return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
python/sglang/srt/utils.py

@@ -14,6 +14,7 @@
 """Common utilities."""

 import base64
+import dataclasses
 import ipaddress
 import itertools
 import json

@@ -1238,49 +1239,37 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


-def should_use_tensor_core(
-    kv_cache_dtype: torch.dtype,
-    num_attention_heads: int,
-    num_kv_heads: int,
-) -> bool:
-    """
-    Determine whether to use tensor cores for attention computation.
-
-    Args:
-        kv_cache_dtype: Data type of the KV cache
-        num_attention_heads: Number of attention heads
-        num_kv_heads: Number of key/value heads
-
-    Returns:
-        bool: Whether to use tensor cores
-    """
-    # Try to use environment variable first
-    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
-    if env_override is not None:
-        return env_override.lower() == "true"
-
-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
-        else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
-    # Calculate GQA group size
-    gqa_group_size = num_attention_heads // num_kv_heads
-
-    # Determine based on dtype and GQA group size
-    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        return True
-    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-        return gqa_group_size > 4
-    else:
-        return False
+def dataclass_to_string_truncated(data, max_length=2048):
+    if isinstance(data, str):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return f'"{data[:half_length]} ... {data[-half_length:]}"'
+        else:
+            return f'"{data}"'
+    elif isinstance(data, (list, tuple)):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return str(data[:half_length]) + " ... " + str(data[-half_length:])
+        else:
+            return str(data)
+    elif isinstance(data, dict):
+        return (
+            "{"
+            + ", ".join(
+                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                for k, v in data.items()
+            )
+            + "}"
+        )
+    elif dataclasses.is_dataclass(data):
+        fields = dataclasses.fields(data)
+        return (
+            f"{data.__class__.__name__}("
+            + ", ".join(
+                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
+                for f in fields
+            )
+            + ")"
+        )
+    else:
+        return str(data)
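To make the request-logging change concrete, here is an illustrative example (not from the commit) of what the truncated formatting produces; DummyReq is a hypothetical dataclass standing in for a real request object:

    import dataclasses

    from sglang.srt.utils import dataclass_to_string_truncated


    @dataclasses.dataclass
    class DummyReq:            # hypothetical, for illustration only
        rid: str
        text: str


    req = DummyReq(rid="abc", text="x" * 5000)
    print(dataclass_to_string_truncated(req, max_length=16))
    # Prints something like: DummyReq(rid="abc", text="xxxxxxxx ... xxxxxxxx")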
test/srt/test_metrics.py

@@ -51,8 +51,10 @@ class TestEnableMetrics(unittest.TestCase):
         # Verify essential metrics are present
         essential_metrics = [
             "sglang:num_running_reqs",
             "sglang:num_used_tokens",
             "sglang:token_usage",
             "sglang:gen_throughput",
             "sglang:num_queue_reqs",
             "sglang:cache_hit_rate",
             "sglang:func_latency_seconds",
             "sglang:prompt_tokens_total",
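As a usage sketch (not part of the test), these names can be checked against a running server's Prometheus scrape output; the URL and port below are assumptions for a locally launched server with metrics enabled, adjust them to your launch arguments:

    import urllib.request

    # Hypothetical check against a local server with metrics enabled.
    metrics_text = urllib.request.urlopen("http://127.0.0.1:30000/metrics").read().decode()
    for name in ["sglang:num_running_reqs", "sglang:token_usage", "sglang:func_latency_seconds"]:
        assert name in metrics_text, f"missing metric: {name}"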