Commit c1f401fc (unverified)
Authored Nov 17, 2024 by Lianmin Zheng; committed via GitHub on Nov 17, 2024

Revert "chore: update torch v2.5.1" (#2063)

Parent: 3b878863
Showing 10 changed files with 37 additions and 174 deletions (+37 -174).
Changed files:

.github/workflows/pr-test.yml                       +1   -1
python/pyproject.toml                               +1   -1
python/sglang/srt/layers/activation.py              +0   -2
python/sglang/srt/layers/layernorm.py               +0   -2
python/sglang/srt/model_executor/model_runner.py    +7   -14
python/sglang/srt/utils.py                          +21  -61
test/srt/test_bench_serving.py                      +1   -7
test/srt/test_nightly_gsm8k_eval.py                 +4   -84
test/srt/test_torch_compile.py                      +1   -1
test/srt/test_torch_compile_moe.py                  +1   -1
.github/workflows/pr-test.yml
@@ -47,7 +47,7 @@ jobs:
           bash scripts/ci_install_dependency.sh
 
       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 25
         run: |
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
python/pyproject.toml
@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
   "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
   "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
   "outlines>=0.0.44,<0.1.0", "modelscope"]
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.4.post1"]
+srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
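The only functional change here is pinning vllm back from 0.6.4.post1 to 0.6.3.post1; everything else is context. A minimal sketch of a runtime guard for that pin (check_vllm_version is a hypothetical helper, not part of this commit; packaging is already listed in runtime_common):

import importlib.metadata

from packaging.version import Version

EXPECTED_VLLM = Version("0.6.3.post1")  # the version this revert pins

def check_vllm_version() -> None:
    # Fail fast instead of crashing later on 0.6.4-only APIs
    # (VllmConfig, CustomOp.register, ModelConfig task resolution).
    installed = Version(importlib.metadata.version("vllm"))
    if installed != EXPECTED_VLLM:
        raise RuntimeError(
            f"sglang[srt] expects vllm=={EXPECTED_VLLM}, found {installed}"
        )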
python/sglang/srt/layers/activation.py
@@ -38,7 +38,6 @@ from sglang.srt.utils import set_weight_attrs
 logger = logging.getLogger(__name__)
 
 
-@CustomOp.register("silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -52,7 +51,6 @@ class SiluAndMul(CustomOp):
         return out
 
 
-@CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
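For context: vllm 0.6.4 required CustomOp subclasses to also be registered by name, which is why @CustomOp.register("silu_and_mul") and @CustomOp.register("gelu_and_mul") disappear once the pin moves back to 0.6.3, where subclassing alone is enough. A plain-PyTorch sketch of what forward_native computes, assuming only the shape logic visible in the diff (no vllm dependency):

import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half; SiLU-activate the first half and use it
    # to gate the second half (matches d = x.shape[-1] // 2 in the diff).
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def gelu_and_mul(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]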
python/sglang/srt/layers/layernorm.py
@@ -36,7 +36,6 @@ from vllm.model_executor.custom_op import CustomOp
 logger = logging.getLogger(__name__)
 
 
-@CustomOp.register("rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +78,6 @@ class RMSNorm(CustomOp):
         return x, residual
 
 
-@CustomOp.register("gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
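Same pattern as activation.py: only the 0.6.4-style registration decorators are removed; RMSNorm and GemmaRMSNorm themselves are untouched. As a reminder of the computation, a plain-PyTorch RMSNorm sketch (the real CustomOp also offers a fused residual path not shown here):

import torch

class RMSNormRef(torch.nn.Module):
    # Scale each hidden vector by the reciprocal of its root mean square.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        h = x.float()
        variance = h.pow(2).mean(dim=-1, keepdim=True)
        h = h * torch.rsqrt(variance + self.variance_epsilon)
        return (h * self.weight.float()).to(orig_dtype)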
python/sglang/srt/model_executor/model_runner.py
@@ -28,7 +28,6 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.config import VllmConfig
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
@@ -60,7 +59,6 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     monkey_patch_vllm_dummy_weight_loader,
-    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -245,14 +243,12 @@ class ModelRunner:
         # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
-        monkey_patch_vllm_model_config()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
         self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
-            task="generate" if self.model_config.is_generation else "embedding",
             quantization=self.server_args.quantization,
             tokenizer=None,
             tokenizer_mode=None,
@@ -267,17 +263,15 @@ class ModelRunner:
         )
         self.dtype = self.vllm_model_config.dtype
-        self.vllm_config = VllmConfig()
-        self.vllm_config.model_config = self.vllm_model_config
-        self.vllm_config.load_config = self.load_config
-        self.vllm_config.device_config = DeviceConfig(self.device)
-        self.vllm_config.quant_config = VllmConfig._get_quantization_config(
-            self.vllm_config.model_config, self.vllm_config.load_config
-        )
 
         # Load the model
         self.model = get_model(
-            vllm_config=self.vllm_config,
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+            parallel_config=None,
+            scheduler_config=None,
+            lora_config=None,
+            cache_config=None,
         )
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
@@ -312,7 +306,6 @@ class ModelRunner:
         # TODO: Use a better method to check this
         vllm_model_config = VllmModelConfig(
             model=model_path,
-            task="generate" if self.model_config.is_generation else "embedding",
             quantization=self.server_args.quantization,
             tokenizer=None,
             tokenizer_mode=None,
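The shape of this change: the torch-2.5.1 branch assembled a single VllmConfig bag of sub-configs and handed it to get_model(vllm_config=...), whereas vllm 0.6.3's loader takes each sub-config as its own keyword, with the unused ones passed as None. A self-contained illustration of the two call shapes (stub types, not vllm's real classes):

from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class DeviceConfigStub:
    device: str = "cuda"

@dataclass
class VllmConfigStub:
    # 0.6.4 style: one object owning every sub-config.
    model_config: Any = None
    load_config: Any = None
    device_config: DeviceConfigStub = field(default_factory=DeviceConfigStub)

def get_model_v064(*, vllm_config: VllmConfigStub) -> None:
    # 0.6.4 style: the loader unpacks whatever it needs from the bag.
    print("device:", vllm_config.device_config.device)

def get_model_v063(
    *,
    model_config: Any,
    load_config: Any,
    device_config: DeviceConfigStub,
    parallel_config: Optional[Any],
    scheduler_config: Optional[Any],
    lora_config: Optional[Any],
    cache_config: Optional[Any],
) -> None:
    # 0.6.3 style: seven explicit keywords, mirroring the '+' side above.
    print("device:", device_config.device)

get_model_v064(vllm_config=VllmConfigStub())
get_model_v063(
    model_config=None,
    load_config=None,
    device_config=DeviceConfigStub(),
    parallel_config=None,
    scheduler_config=None,
    lora_config=None,
    cache_config=None,
)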
python/sglang/srt/utils.py
@@ -410,23 +410,37 @@ def monkey_patch_vllm_dummy_weight_loader():
     Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
     """
-    from vllm.config import VllmConfig
     from vllm.model_executor.model_loader.loader import (
+        CacheConfig,
+        DeviceConfig,
         DummyModelLoader,
+        LoRAConfig,
+        ModelConfig,
+        ParallelConfig,
+        SchedulerConfig,
         _initialize_model,
         initialize_dummy_weights,
         nn,
         set_default_torch_dtype,
     )
 
-    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
-        with set_default_torch_dtype(vllm_config.model_config.dtype):
-            with torch.device(vllm_config.device_config.device):
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+    ) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
                 model = _initialize_model(
-                    vllm_config.model_config,
+                    model_config,
                     self.load_config,
-                    vllm_config.lora_config,
-                    vllm_config.cache_config,
+                    lora_config,
+                    cache_config,
                 )
         for _, module in model.named_modules():
@@ -498,60 +512,6 @@ def maybe_set_triton_cache_manager() -> None:
         os.environ["TRITON_CACHE_MANAGER"] = manager
 
 
-def monkey_patch_vllm_model_config():
-    from typing import Dict, Set, Tuple, Union
-
-    from transformers import PretrainedConfig
-    from vllm.config import ModelConfig, TaskOption, _Task
-
-    def _resolve_task(
-        self,
-        task_option: Union[TaskOption, _Task],
-        hf_config: PretrainedConfig,
-    ) -> Tuple[Set[_Task], _Task]:
-        architectures = getattr(hf_config, "architectures", [])
-        if isinstance(architectures, str):
-            architectures = [architectures]
-
-        non_generation_models = {
-            "LlamaEmbeddingModel",
-            "MistralModel",
-            "LlamaForSequenceClassification",
-            "LlamaForSequenceClassificationWithNormal_Weights",
-            "InternLM2ForRewardModel",
-        }
-        is_generation = not any(
-            arch in non_generation_models for arch in architectures
-        )
-
-        auto_map = getattr(hf_config, "auto_map", {})
-        has_sequence_classification = any(
-            "ForSequenceClassification" in v for v in auto_map.values()
-        )
-
-        task_support: Dict[_Task, bool] = {
-            "generate": is_generation,
-            "embedding": (not is_generation) or has_sequence_classification,
-        }
-        supported_tasks_lst = [
-            task for task, is_supported in task_support.items() if is_supported
-        ]
-        supported_tasks = set(supported_tasks_lst)
-
-        if task_option not in supported_tasks:
-            msg = (
-                f"This model does not support the '{task_option}' task. "
-                f"Supported tasks: {supported_tasks}"
-            )
-            raise ValueError(msg)
-        selected_task = task_option
-
-        return supported_tasks, selected_task
-
-    setattr(ModelConfig, "_resolve_task", _resolve_task)
-
-
 class CustomCacheManager(FileCacheManager):
     # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
     def __init__(self, key, override=False, dump=False):
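Both hunks here are monkey patches: load_model is rebound onto vllm's DummyModelLoader with the 0.6.3 signature, and monkey_patch_vllm_model_config (only needed to satisfy vllm 0.6.4's task resolution) is deleted outright. The underlying rebinding pattern, as a minimal self-contained sketch:

class Loader:
    # Stand-in for vllm's DummyModelLoader.
    def load_model(self) -> str:
        return "weights loaded from a checkpoint"

def load_model_patched(self) -> str:
    # Replacement behavior: skip the checkpoint and fabricate dummy weights,
    # analogous to what monkey_patch_vllm_dummy_weight_loader installs.
    return "dummy weights"

# Rebinding the attribute on the class affects every existing and future
# instance; setattr(Loader, "load_model", load_model_patched) is equivalent,
# which is exactly how the removed code attached _resolve_task to ModelConfig.
Loader.load_model = load_model_patched

assert Loader().load_model() == "dummy weights"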
test/srt/test_bench_serving.py
-import sys
 import unittest
 
 from sglang.test.test_utils import (
@@ -36,12 +35,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            print(
-                f"Output throughput: {res['output_throughput']}, Is greater than 1000: {res['output_throughput'] > 1000}",
-                file=sys.stderr,
-            )
-            # TODO(zhyncs) fix this
-            # assert res["output_throughput"] > 1000
+            assert res["output_throughput"] > 1000
 
     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
test/srt/test_nightly_gsm8k_eval.py
-import json
-import os
 import unittest
-from datetime import datetime
 from types import SimpleNamespace
 
 from sglang.srt.utils import kill_child_process
@@ -17,26 +14,6 @@ from sglang.test.test_utils import (
     popen_launch_server,
 )
 
-MODEL_SCORE_THRESHOLDS = {
-    "meta-llama/Llama-3.1-8B-Instruct": 0.8316,
-    "mistralai/Mistral-7B-Instruct-v0.3": 0.5861,
-    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.8672,
-    "google/gemma-2-27b-it": 0.9227,
-    "meta-llama/Llama-3.1-70B-Instruct": 0.9623,
-    "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.6415,
-    "Qwen/Qwen2-57B-A14B-Instruct": 0.8791,
-    "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.8672,
-    "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.5544,
-    "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.8356,
-    "neuralmagic/gemma-2-2b-it-FP8": 0.6059,
-    "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.9504,
-    "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.6138,
-    "neuralmagic/Qwen2-72B-Instruct-FP8": 0.9504,
-    "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.8197,
-    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.8395,
-    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.8435,
-}
-
 
 def parse_models(model_string):
     return [model.strip() for model in model_string.split(",") if model.strip()]
@@ -46,8 +23,10 @@ def launch_server(base_url, model, is_fp8, is_tp2):
     other_args = ["--log-level-http", "warning", "--trust-remote-code"]
     if is_fp8:
         if "Llama-3" in model or "gemma-2" in model:
+            # compressed-tensors
             other_args.extend(["--kv-cache-dtype", "fp8_e5m2"])
         elif "Qwen2-72B-Instruct-FP8" in model:
+            # bug
             other_args.extend(["--quantization", "fp8"])
         else:
             other_args.extend(["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"])
@@ -69,49 +48,6 @@ def launch_server(base_url, model, is_fp8, is_tp2):
     return process
 
 
-def write_results_to_json(model, metrics, mode="a"):
-    result = {
-        "timestamp": datetime.now().isoformat(),
-        "model": model,
-        "metrics": metrics,
-        "score": metrics["score"],
-    }
-
-    existing_results = []
-    if mode == "a" and os.path.exists("results.json"):
-        try:
-            with open("results.json", "r") as f:
-                existing_results = json.load(f)
-        except json.JSONDecodeError:
-            existing_results = []
-
-    if isinstance(existing_results, list):
-        existing_results.append(result)
-    else:
-        existing_results = [result]
-
-    with open("results.json", "w") as f:
-        json.dump(existing_results, f, indent=2)
-
-
-def check_model_scores(results):
-    failed_models = []
-
-    for model, score in results:
-        threshold = MODEL_SCORE_THRESHOLDS.get(model)
-        if threshold is None:
-            print(f"Warning: No threshold defined for model {model}")
-            continue
-
-        if score < threshold:
-            failed_models.append(
-                f"\nScore Check Failed: {model}\n"
-                f"Model {model} score ({score:.4f}) is below threshold ({threshold:.4f})"
-            )
-
-    if failed_models:
-        raise AssertionError("\n".join(failed_models))
-
-
 class TestEvalAccuracyLarge(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -132,9 +68,6 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         kill_child_process(self.process.pid, include_self=True)
 
     def test_mgsm_en_all_models(self):
-        is_first = True
-        all_results = []
-
         for model_group, is_fp8, is_tp2 in self.model_groups:
             for model in model_group:
                 with self.subTest(model=model):
@@ -152,24 +85,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
                     print(
                         f"{'='*42}\n{model} - metrics={metrics} score={metrics['score']}\n{'='*42}\n"
                     )
 
-                    write_results_to_json(model, metrics, "w" if is_first else "a")
-                    is_first = False
-
-                    all_results.append((model, metrics["score"]))
+                    # loosely threshold
+                    assert metrics["score"] > 0.5, f"score={metrics['score']} <= 0.5"
 
         self.tearDown()
 
-        try:
-            with open("results.json", "r") as f:
-                print("\nFinal Results from results.json:")
-                print(json.dumps(json.load(f), indent=2))
-        except Exception as e:
-            print(f"Error reading results.json: {e}")
-
-        # Check all scores after collecting all results
-        check_model_scores(all_results)
-
 
 if __name__ == "__main__":
     unittest.main()
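After the revert, the nightly eval keeps only the single loose in-loop threshold instead of the per-model table, JSON result log, and final sweep. Condensed, the check that remains looks like this (the metrics value is hypothetical):

metrics = {"score": 0.62}  # hypothetical eval result

# loosely threshold (as in the restored test body)
assert metrics["score"] > 0.5, f"score={metrics['score']} <= 0.5"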
test/srt/test_torch_compile.py
@@ -66,7 +66,7 @@ class TestTorchCompile(unittest.TestCase):
         print(res["text"])
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
-        self.assertGreaterEqual(throughput, 151)
+        self.assertGreaterEqual(throughput, 152)
 
 
 if __name__ == "__main__":
test/srt/test_torch_compile_moe.py
@@ -66,7 +66,7 @@ class TestTorchCompile(unittest.TestCase):
         print(f"{res=}")
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
-        self.assertGreaterEqual(throughput, 289)
+        self.assertGreaterEqual(throughput, 290)
 
 
 if __name__ == "__main__":
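The throughput both compile tests assert on is simply generated tokens over wall-clock seconds; the revert raises the floors back to 152 and 290 tokens/s. A sketch of the measurement, with a stand-in for the generation call (the real tests time an end-to-end generation request):

import time

def fake_generate(max_tokens: int) -> None:
    # Stand-in for the generation call the real tests time.
    time.sleep(0.01)

max_tokens = 256
tic = time.perf_counter()
fake_generate(max_tokens)
tok = time.perf_counter()  # tic/tok naming follows the tests

throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s")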