Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
df7fe452
Unverified
Commit
df7fe452
authored
Nov 17, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 17, 2024
Browse files
Crash the CI jobs on model import errors (#2072)
parent
a7164b62
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
30 additions
and
25 deletions
+30
-25
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+3
-6
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+12
-13
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+5
-1
python/sglang/srt/models/phi3_small.py
python/sglang/srt/models/phi3_small.py
+5
-5
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+5
-0
No files found.
python/sglang/srt/layers/sampler.py
View file @
df7fe452
...
@@ -8,7 +8,7 @@ from torch import nn
...
@@ -8,7 +8,7 @@ from torch import nn
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.utils
import
is_flashinfer_available
from
sglang.srt.utils
import
crash_on_warnings
,
is_flashinfer_available
if
is_flashinfer_available
():
if
is_flashinfer_available
():
from
flashinfer.sampling
import
(
from
flashinfer.sampling
import
(
...
@@ -19,10 +19,6 @@ if is_flashinfer_available():
...
@@ -19,10 +19,6 @@ if is_flashinfer_available():
)
)
# Crash on warning if we are running CI tests
crash_on_warning
=
os
.
getenv
(
"SGLANG_IS_IN_CI"
,
"false"
)
==
"true"
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
...
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
logits
=
torch
.
where
(
logits
=
torch
.
where
(
torch
.
isnan
(
logits
),
torch
.
full_like
(
logits
,
-
1e5
),
logits
torch
.
isnan
(
logits
),
torch
.
full_like
(
logits
,
-
1e5
),
logits
)
)
exit
(
1
)
if
crash_on_warning
else
None
if
crash_on_warnings
():
raise
ValueError
(
"Detected errors during sampling! NaN in the logits."
)
if
sampling_info
.
is_all_greedy
:
if
sampling_info
.
is_all_greedy
:
# Use torch.argmax if all requests use greedy sampling
# Use torch.argmax if all requests use greedy sampling
...
...
python/sglang/srt/managers/scheduler.py
View file @
df7fe452
...
@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
...
@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
broadcast_pyobj
,
broadcast_pyobj
,
configure_logger
,
configure_logger
,
crash_on_warnings
,
get_zmq_socket
,
get_zmq_socket
,
kill_parent_process
,
kill_parent_process
,
set_random_seed
,
set_random_seed
,
...
@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
...
@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Crash on warning if we are running CI tests
crash_on_warning
=
os
.
getenv
(
"SGLANG_IS_IN_CI"
,
"false"
)
==
"true"
# Test retract decode
# Test retract decode
test_retract
=
os
.
getenv
(
"SGLANG_TEST_RETRACT"
,
"false"
)
==
"true"
test_retract
=
os
.
getenv
(
"SGLANG_TEST_RETRACT"
,
"false"
)
==
"true"
...
@@ -662,21 +659,23 @@ class Scheduler:
...
@@ -662,21 +659,23 @@ class Scheduler:
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
)
if
available_size
!=
self
.
max_total_num_tokens
:
if
available_size
!=
self
.
max_total_num_tokens
:
warnings
.
warn
(
msg
=
(
"Warning: "
f
"available_size=
{
available_size
}
, max_total_num_tokens=
{
self
.
max_total_num_tokens
}
\n
"
"KV cache pool leak detected!"
"KV cache pool leak detected!"
f
"
{
available_size
=
}
,
{
self
.
max_total_num_tokens
=
}
\n
"
)
)
exit
(
1
)
if
crash_on_warning
else
None
warnings
.
warn
(
msg
)
if
crash_on_warnings
():
raise
ValueError
(
msg
)
if
len
(
self
.
req_to_token_pool
.
free_slots
)
!=
self
.
req_to_token_pool
.
size
:
if
len
(
self
.
req_to_token_pool
.
free_slots
)
!=
self
.
req_to_token_pool
.
size
:
warnings
.
warn
(
msg
=
(
"Warning: "
f
"available req slots=
{
len
(
self
.
req_to_token_pool
.
free_slots
)
}
, "
f
"total slots=
{
self
.
req_to_token_pool
.
size
}
\n
"
"Memory pool leak detected!"
"Memory pool leak detected!"
f
"available_size=
{
len
(
self
.
req_to_token_pool
.
free_slots
)
}
, "
f
"total_size=
{
self
.
req_to_token_pool
.
size
}
\n
"
)
)
exit
(
1
)
if
crash_on_warning
else
None
warnings
.
warn
(
msg
)
if
crash_on_warnings
():
raise
ValueError
(
msg
)
def
get_next_batch_to_run
(
self
):
def
get_next_batch_to_run
(
self
):
# Merge the prefill batch into the running batch
# Merge the prefill batch into the running batch
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
df7fe452
...
@@ -20,6 +20,7 @@ import importlib
...
@@ -20,6 +20,7 @@ import importlib
import
importlib.resources
import
importlib.resources
import
json
import
json
import
logging
import
logging
import
os
import
pkgutil
import
pkgutil
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Optional
,
Type
from
typing
import
Optional
,
Type
...
@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...
@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
crash_on_warnings
,
enable_show_time_cost
,
enable_show_time_cost
,
get_available_gpu_memory
,
get_available_gpu_memory
,
monkey_patch_vllm_p2p_access_check
,
monkey_patch_vllm_p2p_access_check
,
...
@@ -665,7 +667,9 @@ def import_model_classes():
...
@@ -665,7 +667,9 @@ def import_model_classes():
try
:
try
:
module
=
importlib
.
import_module
(
name
)
module
=
importlib
.
import_module
(
name
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
warning
(
f
"Ignore import error when loading
{
name
}
. "
f
"
{
e
}
"
)
logger
.
warning
(
f
"Ignore import error when loading
{
name
}
.
{
e
}
"
)
if
crash_on_warnings
():
raise
ValueError
(
f
"Ignore import error when loading
{
name
}
.
{
e
}
"
)
continue
continue
if
hasattr
(
module
,
"EntryClass"
):
if
hasattr
(
module
,
"EntryClass"
):
entry
=
module
.
EntryClass
entry
=
module
.
EntryClass
...
...
python/sglang/srt/models/phi3_small.py
View file @
df7fe452
import
math
import
math
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Phi3Config
from
transformers
import
Phi3Config
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.utils
import
make_layers
,
maybe_prefix
from
vllm.model_executor.models.utils
import
make_layers
from
sglang.srt.layers.linear
import
(
from
sglang.srt.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
...
@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
self
,
self
,
config
:
Phi3Config
,
config
:
Phi3Config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
cache_config
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
...
@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
self
.
model
=
Phi3SmallModel
(
self
.
model
=
Phi3SmallModel
(
config
=
config
,
config
=
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
,
prefix
=
"model"
,
)
)
self
.
torchao_config
=
global_server_args_dict
[
"torchao_config"
]
self
.
torchao_config
=
global_server_args_dict
[
"torchao_config"
]
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
...
python/sglang/srt/utils.py
View file @
df7fe452
...
@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
...
@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
raise
RuntimeError
(
raise
RuntimeError
(
"nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
"nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
)
)
def
crash_on_warnings
():
# Crash on warning if we are running CI tests
return
os
.
getenv
(
"SGLANG_IS_IN_CI"
,
"false"
)
==
"true"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment