Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a78d8f8d
Unverified
Commit
a78d8f8d
authored
Nov 23, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 23, 2024
Browse files
[CI] Fix test cases (#2137)
parent
c5f86501
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
21 deletions
+27
-21
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
...glang/srt/layers/attention/triton_ops/decode_attention.py
+4
-2
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
...glang/srt/layers/attention/triton_ops/extend_attention.py
+3
-1
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+11
-9
test/srt/test_srt_engine.py
test/srt/test_srt_engine.py
+9
-9
No files found.
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
View file @
a78d8f8d
...
@@ -24,6 +24,8 @@ import triton.language as tl
...
@@ -24,6 +24,8 @@ import triton.language as tl
from
sglang.srt.utils
import
is_hip
from
sglang.srt.utils
import
is_hip
is_hip_
=
is_hip
()
@
triton
.
jit
@
triton
.
jit
def
tanh
(
x
):
def
tanh
(
x
):
...
@@ -501,7 +503,7 @@ def _decode_grouped_att_m_fwd(
...
@@ -501,7 +503,7 @@ def _decode_grouped_att_m_fwd(
num_warps
=
4
num_warps
=
4
extra_kargs
=
{}
extra_kargs
=
{}
if
is_hip
()
:
if
is_hip
_
:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
...
@@ -557,7 +559,7 @@ def _decode_grouped_softmax_reducev_fwd(
...
@@ -557,7 +559,7 @@ def _decode_grouped_softmax_reducev_fwd(
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lv
)
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lv
)
extra_kargs
=
{}
extra_kargs
=
{}
if
is_hip
()
:
if
is_hip
_
:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
...
...
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
View file @
a78d8f8d
...
@@ -29,6 +29,8 @@ is_cuda_available = torch.cuda.is_available()
...
@@ -29,6 +29,8 @@ is_cuda_available = torch.cuda.is_available()
if
is_cuda_available
:
if
is_cuda_available
:
CUDA_CAPABILITY
=
torch
.
cuda
.
get_device_capability
()
CUDA_CAPABILITY
=
torch
.
cuda
.
get_device_capability
()
is_hip_
=
is_hip
()
@
triton
.
jit
@
triton
.
jit
def
tanh
(
x
):
def
tanh
(
x
):
...
@@ -311,7 +313,7 @@ def extend_attention_fwd(
...
@@ -311,7 +313,7 @@ def extend_attention_fwd(
num_stages
=
1
num_stages
=
1
extra_kargs
=
{}
extra_kargs
=
{}
if
is_hip
()
:
if
is_hip
_
:
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
_fwd_kernel
[
grid
](
_fwd_kernel
[
grid
](
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
a78d8f8d
...
@@ -242,15 +242,17 @@ class ModelRunner:
...
@@ -242,15 +242,17 @@ class ModelRunner:
)
)
return
get_model
(
vllm_config
=
vllm_config
)
return
get_model
(
vllm_config
=
vllm_config
)
except
ImportError
:
except
ImportError
:
return
get_model
(
pass
model_config
=
self
.
vllm_model_config
,
load_config
=
self
.
load_config
,
return
get_model
(
device_config
=
DeviceConfig
(
self
.
device
),
model_config
=
self
.
vllm_model_config
,
parallel_config
=
None
,
load_config
=
self
.
load_config
,
scheduler_config
=
None
,
device_config
=
DeviceConfig
(
self
.
device
),
lora_config
=
None
,
parallel_config
=
None
,
cache_config
=
None
,
scheduler_config
=
None
,
)
lora_config
=
None
,
cache_config
=
None
,
)
def
get_model_config_params
(
self
):
def
get_model_config_params
(
self
):
sig
=
inspect
.
signature
(
VllmModelConfig
.
__init__
)
sig
=
inspect
.
signature
(
VllmModelConfig
.
__init__
)
...
...
test/srt/test_srt_engine.py
View file @
a78d8f8d
...
@@ -152,15 +152,7 @@ class TestSRTEngine(unittest.TestCase):
...
@@ -152,15 +152,7 @@ class TestSRTEngine(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-5
,
rtol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-5
,
rtol
=
1e-3
))
def
test_7_engine_offline_throughput
(
self
):
def
test_7_engine_cpu_offload
(
self
):
server_args
=
ServerArgs
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
)
bench_args
=
BenchArgs
(
num_prompts
=
10
)
result
=
throughput_test
(
server_args
=
server_args
,
bench_args
=
bench_args
)
self
.
assertGreater
(
result
[
"total_throughput"
],
3500
)
def
test_8_engine_cpu_offload
(
self
):
prompt
=
"Today is a sunny day and I like"
prompt
=
"Today is a sunny day and I like"
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
...
@@ -190,6 +182,14 @@ class TestSRTEngine(unittest.TestCase):
...
@@ -190,6 +182,14 @@ class TestSRTEngine(unittest.TestCase):
print
(
out2
)
print
(
out2
)
self
.
assertEqual
(
out1
,
out2
)
self
.
assertEqual
(
out1
,
out2
)
def
test_8_engine_offline_throughput
(
self
):
server_args
=
ServerArgs
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
)
bench_args
=
BenchArgs
(
num_prompts
=
10
)
result
=
throughput_test
(
server_args
=
server_args
,
bench_args
=
bench_args
)
self
.
assertGreater
(
result
[
"total_throughput"
],
3500
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment