change / sglang · Commit 66301e12 (Unverified)

Improve code styles (#4021)
Authored Mar 03, 2025 by Lianmin Zheng; committed via GitHub on Mar 03, 2025.
Parent: ac238727

Showing 14 changed files with 74 additions and 229 deletions (+74, -229).
Changed files:

  benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py   +5   -5
  python/sglang/bench_serving.py                                            +1   -1
  python/sglang/lang/backend/runtime_endpoint.py                            +1   -6
  python/sglang/srt/layers/logits_processor.py                              +1   -1
  python/sglang/srt/layers/moe/ep_moe/layer.py                              +0   -1
  python/sglang/srt/managers/data_parallel_controller.py                    +0   -3
  python/sglang/srt/managers/io_struct.py                                   +1   -1
  python/sglang/srt/managers/schedule_batch.py                              +2   -5
  python/sglang/srt/managers/scheduler.py                                   +0   -8
  python/sglang/srt/metrics/collector.py                                    +0   -114
  python/sglang/srt/model_executor/forward_batch_info.py                    +3   -1
  python/sglang/srt/server_args.py                                          +1   -2
  python/sglang/test/few_shot_gsm8k.py                                      +4   -1
  sgl-kernel/src/sgl-kernel/__init__.py                                     +55  -80
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py (view file @ 66301e12)

@@ -30,6 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
+        E = config.n_routed_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
@@ -39,11 +44,6 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = config.n_routed_experts
-        topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
         # Default: Mixtral
         E = config.num_local_experts
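The change moves the DeepseekV2/V3 branch earlier in the architecture dispatch. A minimal sketch of that dispatch pattern, assuming HuggingFace-style config attributes (the helper name is hypothetical):

# Hypothetical sketch of architecture-keyed MoE sizing, mirroring the
# get_model_config branches above. `config` stands in for a HuggingFace
# PretrainedConfig; attribute names vary per architecture.
def moe_shape(config, tp_size: int):
    arch = config.architectures[0]
    if arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"):
        E = config.n_routed_experts  # routed experts only
        intermediate_size = config.intermediate_size
    else:
        # Default: Mixtral-style configs
        E = config.num_local_experts
        intermediate_size = config.intermediate_size
    topk = config.num_experts_per_tok
    # Each TP rank holds gate + up projections, hence the factor of 2.
    shard_intermediate_size = 2 * intermediate_size // tp_size
    return E, topk, shard_intermediate_size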
python/sglang/bench_serving.py (view file @ 66301e12)

@@ -393,7 +393,7 @@ async def async_request_sglang_generate(
                             output.itl.extend([adjust_itl] * num_new_tokens)
 
                         most_recent_timestamp = timestamp
-                        generated_text = data["text"]
+                        last_output_len = output_len
 
                 output.generated_text = generated_text
                 output.success = True
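For context, `adjust_itl` in the hunk above spreads one streamed chunk's latency evenly across the tokens it delivered. A standalone sketch of that inter-token-latency bookkeeping (variable names follow the hunk; the function itself is illustrative):

# Illustrative ITL bookkeeping for a streaming benchmark: when one chunk
# delivers several new tokens, attribute an equal share of the elapsed
# time to each token, as the hunk above does with adjust_itl.
def record_itl(itl: list, timestamp: float, most_recent_timestamp: float,
               output_len: int, last_output_len: int) -> None:
    num_new_tokens = output_len - last_output_len
    if num_new_tokens <= 0:
        return  # usage-only chunk: no tokens to attribute
    adjust_itl = (timestamp - most_recent_timestamp) / num_new_tokens
    itl.extend([adjust_itl] * num_new_tokens)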
python/sglang/lang/backend/runtime_endpoint.py (view file @ 66301e12)

@@ -329,12 +329,7 @@ class RuntimeEndpoint(BaseBackend):
 
 def compute_normalized_prompt_logprobs(input_logprobs):
     values = [x[0] for x in input_logprobs if x[0]]
-    try:
-        return sum(values) / len(values)
-    except TypeError:
-        print(f"{input_logprobs=}", flush=True)
-        print(f"{input_logprobs[0]=}", flush=True)
-        exit(-1)
+    return sum(values) / len(values)
 
 
 class Runtime:
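The deleted try/except printed debug state and called exit(-1), which kills the whole process on malformed input. Letting the exception propagate, or raising a descriptive one, is the usual alternative; a hedged sketch:

# Sketch: raise a descriptive error instead of print-and-exit(-1).
# Killing the interpreter from library code hides the traceback from
# callers; an exception keeps the failure debuggable and catchable.
def compute_normalized_prompt_logprobs(input_logprobs):
    values = [x[0] for x in input_logprobs if x[0]]
    if not values:
        raise ValueError(f"no usable logprobs in {input_logprobs!r}")
    return sum(values) / len(values)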
python/sglang/srt/layers/logits_processor.py (view file @ 66301e12)

@@ -64,7 +64,7 @@ class LogitsProcessorOutput:
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logprobs of input tokens. shape: [#token]
-    input_token_logprobs: torch.Tensor = None
+    input_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
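Annotating a field that defaults to None as a bare `torch.Tensor` is rejected by type checkers; `Optional[torch.Tensor]` states the actual contract. A minimal sketch of the pattern (the dataclass below is illustrative, not the real LogitsProcessorOutput):

# Illustrative dataclass showing why Optional matters: the fields are
# None until the prefill pass fills them in, so the annotations say so.
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Output:
    # Filled in only for prefill; None during decode.
    input_token_logprobs: Optional[torch.Tensor] = None
    input_top_logprobs_val: Optional[List] = None
    input_top_logprobs_idx: Optional[List] = None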
python/sglang/srt/layers/moe/ep_moe/layer.py (view file @ 66301e12)

@@ -181,7 +181,6 @@ class EPMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
-        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
python/sglang/srt/managers/data_parallel_controller.py (view file @ 66301e12)

@@ -198,8 +198,6 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
-        print(f"{scheduler_info=}")
-
 
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

@@ -222,7 +220,6 @@ class DataParallelController:
                     TokenizedEmbeddingReqInput,
                 ),
             ):
-                logger.info("dispatching")
                 self.dispatching(recv_req)
             else:
                 # Send other control messages to first worker of tp group
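With the debug print and log line gone, `round_robin_scheduler` remains the dispatch policy: each request goes to the next worker modulo the pool size. A self-contained sketch of that counter pattern (in the real controller the workers are ZMQ sockets and send is `send_pyobj`; callables stand in here so the rotation is runnable on its own):

# Minimal round-robin dispatcher sketch; workers are stand-in callables.
class RoundRobinDispatcher:
    def __init__(self, workers):
        self.workers = workers
        self.round_robin_counter = 0

    def dispatch(self, req):
        self.workers[self.round_robin_counter](req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)


received = []
d = RoundRobinDispatcher([lambda r, i=i: received.append((i, r)) for i in range(3)])
for req in range(5):
    d.dispatch(req)
assert [w for w, _ in received] == [0, 1, 2, 0, 1]  # wraps around the pool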
python/sglang/srt/managers/io_struct.py (view file @ 66301e12)

@@ -158,7 +158,7 @@ class GenerateReqInput:
             # Expand parallel_sample_num
             num = self.batch_size * self.parallel_sample_num
 
-            if self.image_data is None:
+            if not self.image_data:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
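`not self.image_data` is a strictly wider test than `is None`: it also catches an empty list or empty string, which should be normalized the same way. A quick sketch of the difference:

# `is None` vs. truthiness when normalizing an optional image field.
# Assumption: an empty list/str should mean "no images", which is what
# switching to `not image_data` buys.
def normalize(image_data, num):
    if not image_data:                    # None, [], and "" all land here
        return [None] * num
    if not isinstance(image_data, list):  # a single item: broadcast it
        return [image_data] * num
    return image_data


assert normalize(None, 2) == [None, None]
assert normalize([], 2) == [None, None]   # `is None` would have missed this
assert normalize("img.png", 2) == ["img.png", "img.png"]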
python/sglang/srt/managers/schedule_batch.py (view file @ 66301e12)

@@ -282,6 +282,8 @@ class Req:
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
+        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
+        self.to_abort_message: str = "Unknown error"
 
         self.stream = stream
         self.eos_token_ids = eos_token_ids

@@ -359,8 +361,6 @@ class Req:
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
-        # The relative logprob_start_len in an extend batch
-        self.extend_logprob_start_len = 0
 
         # Embedding (return values)
         self.embedding = None

@@ -377,9 +377,6 @@ class Req:
         self.spec_verify_ct = 0
         self.lora_path = lora_path
 
-        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
-
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
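The net effect of the first and third hunks is that `to_abort_message` is now initialized right next to the `to_abort` flag it annotates. A sketch of how the comments say the pair is consumed at the end of the event loop (the `finished_reason` shape below is illustrative):

# Sketch of the abort handshake the comments describe: mid-loop code only
# sets the flag and message; the end of the event loop turns them into a
# finished_reason. The dict shape here is an assumption for illustration.
class Req:
    def __init__(self):
        self.to_abort = False
        self.to_abort_message: str = "Unknown error"
        self.finished_reason = None

    def abort(self, message: str) -> None:
        # Never set finished_reason mid-loop, or the req gets filtered early.
        self.to_abort = True
        self.to_abort_message = message

    def finalize(self) -> None:
        if self.to_abort:
            self.finished_reason = {"type": "abort", "message": self.to_abort_message}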
python/sglang/srt/managers/scheduler.py (view file @ 66301e12)

@@ -358,7 +358,6 @@ class Scheduler:
         self.cum_spec_accept_count = 0
         self.last_decode_stats_tic = time.time()
         self.return_health_check_ct = 0
-        self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU

@@ -444,11 +443,6 @@ class Scheduler:
             },
         )
 
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
-
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [

@@ -2309,8 +2303,6 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    parent_process = psutil.Process().parent()
-
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
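One detail worth noting in the first hunk's context: on CPU there is no device stream, so `synchronize` is patched to a no-op lambda, letting the rest of the scheduler call `current_stream.synchronize()` unconditionally. A device-agnostic sketch of the same idea (`torch.get_device_module` appears in the hunk and is a real torch helper in recent PyTorch; the wrapper class is illustrative):

# Illustrative wrapper: give CPU a stream-like object with a no-op
# synchronize() so callers never need per-device branches.
import torch


def current_stream(device: str):
    if device == "cpu":
        class _NullStream:
            def synchronize(self) -> None:  # no-op: CPU ops run eagerly
                pass
        return _NullStream()
    # e.g. "cuda" -> torch.cuda.current_stream()
    return torch.get_device_module(device).current_stream()


stream = current_stream("cpu")
stream.synchronize()  # safe on any device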
python/sglang/srt/metrics/collector.py (view file @ 66301e12)

@@ -238,120 +238,6 @@ class TokenizerMetricsCollector:
             ],
         )
 
-        self.histogram_prefill_prealloc_duration = Histogram(
-            name="sglang:prefill_prealloc_duration_seconds",
-            documentation="Histogram of prefill prealloc duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 120, 160],
-        )
-
-        self.histogram_prefill_queue_duration = Histogram(
-            name="sglang:prefill_queue_duration_seconds",
-            documentation="Histogram of prefill queue duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-
-        self.histogram_prefill_forward_duration = Histogram(
-            name="sglang:prefill_forward_duration_seconds",
-            documentation="Histogram of prefill forward duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-
-        self.histogram_prefill_transfer_duration = Histogram(
-            name="sglang:prefill_transfer_duration_seconds",
-            documentation="Histogram of prefill transfer duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.050, 0.100, 0.150, 0.200, 0.300, 0.400, 0.500, 1.000, 2.000],
-        )
-
-        self.histogram_decode_prealloc_duration = Histogram(
-            name="sglang:decode_prealloc_duration_seconds",
-            documentation="Histogram of decode prealloc duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-
-        self.histogram_decode_queue_duration = Histogram(
-            name="sglang:decode_queue_duration_seconds",
-            documentation="Histogram of decode queue duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-
     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
         histogram.labels(**self.labels).observe(data)
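For reference, the deleted metrics follow the standard prometheus_client pattern: declare a labeled Histogram once, then observe durations against it. A minimal runnable sketch (the label set here is illustrative):

# Minimal prometheus_client Histogram usage, mirroring the deleted
# collector code. Labels and bucket choices here are illustrative.
from prometheus_client import Histogram

histogram_prefill_queue_duration = Histogram(
    name="sglang:prefill_queue_duration_seconds",
    documentation="Histogram of prefill queue duration in seconds.",
    labelnames=["model_name"],
    buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
)


def log_histogram(histogram, data, labels):
    # Each observation increments one bucket counter plus _sum and _count.
    histogram.labels(**labels).observe(data)


log_histogram(histogram_prefill_queue_duration, 0.42, {"model_name": "demo"})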
python/sglang/srt/model_executor/forward_batch_info.py (view file @ 66301e12)

@@ -284,7 +284,9 @@ class ForwardBatch:
         ):
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position_triton(
-                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
             )
         else:
             positions, ret.extend_start_loc = compute_position_torch(
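Both `compute_position_triton` and `compute_position_torch` return the same pair: per-token positions for each sequence's extend segment, and each sequence's start offset in the flattened batch. A plain-PyTorch sketch of that contract, inferred from the names in the hunk (the real kernels' semantics may differ in detail):

# Hedged sketch of what a compute_position_torch-style helper returns:
# positions[i] is the absolute position of the i-th extend token, and
# extend_start_loc[s] is where sequence s starts in the flat batch.
import torch


def compute_position_sketch(extend_prefix_lens, extend_seq_lens):
    positions = torch.cat(
        [
            torch.arange(p, p + l)  # extend tokens continue after the prefix
            for p, l in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens
    return positions, extend_start_loc


pos, loc = compute_position_sketch(torch.tensor([4, 0]), torch.tensor([2, 3]))
assert pos.tolist() == [4, 5, 0, 1, 2] and loc.tolist() == [0, 2]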
python/sglang/srt/server_args.py (view file @ 66301e12)

@@ -62,7 +62,6 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
-    skip_tokenizer_init: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"

@@ -563,7 +562,7 @@ class ServerArgs:
             "--download-dir",
             type=str,
             default=ServerArgs.download_dir,
-            help="Model download directory.",
+            help="Model download directory for huggingface.",
         )
         parser.add_argument(
             "--base-gpu-id",
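The `default=ServerArgs.download_dir` idiom in the second hunk keeps argparse defaults in lockstep with the dataclass: field defaults are readable off the class, so the CLI never drifts from the config object. A small self-contained sketch:

# Sketch of dataclass-backed argparse defaults: the field default is read
# off the class, so changing the dataclass updates the CLI too.
import argparse
from dataclasses import dataclass
from typing import Optional


@dataclass
class ServerArgs:
    download_dir: Optional[str] = None
    host: str = "127.0.0.1"


parser = argparse.ArgumentParser()
parser.add_argument(
    "--download-dir",
    type=str,
    default=ServerArgs.download_dir,  # stays in sync with the dataclass
    help="Model download directory for huggingface.",
)
parser.add_argument("--host", type=str, default=ServerArgs.host)
args = parser.parse_args([])
assert args.host == "127.0.0.1"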
python/sglang/test/few_shot_gsm8k.py (view file @ 66301e12)

@@ -93,9 +93,11 @@ def run_eval(args):
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
-        temperature=0,
+        temperature=args.temperature if hasattr(args, "temperature") else 0,
         num_threads=args.parallel,
         progress_bar=True,
+        return_logprob=getattr(args, "return_logprob", None),
+        logprob_start_len=getattr(args, "logprob_start_len", None),
     )
     latency = time.time() - tic

@@ -141,5 +143,6 @@ if __name__ == "__main__":
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
+    parser.add_argument("--temperature", type=float, default=0.0)
     args = parser.parse_args()
     run_eval(args)
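Both additions use the same defensive pattern: `run_eval` may receive a Namespace built by an older caller that lacks the new flags, so missing attributes fall back to a default via `hasattr`/`getattr` instead of raising AttributeError. Isolated:

# The hasattr/getattr fallback pattern from the hunk above, isolated.
# It lets new optional flags be read safely from an older Namespace.
from types import SimpleNamespace

old_args = SimpleNamespace(parallel=128)                 # no .temperature yet
new_args = SimpleNamespace(parallel=128, temperature=0.7)

for args in (old_args, new_args):
    temperature = args.temperature if hasattr(args, "temperature") else 0
    return_logprob = getattr(args, "return_logprob", None)
    print(temperature, return_logprob)                   # 0 None, then 0.7 None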
sgl-kernel/src/sgl-kernel/__init__.py (view file @ 66301e12)

@@ -8,16 +8,19 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
         "/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
         mode=ctypes.RTLD_GLOBAL,
     )
 
-from .version import __version__
+from sgl_kernel.version import __version__
 
-if torch.version.hip is not None:
+if torch.version.cuda:
     from sgl_kernel.ops import (
-        all_reduce_reg,
-        all_reduce_unreg,
-        allocate_meta_buffer,
         apply_rope_with_cos_sin_cache_inplace,
         bmm_fp8,
-        dispose,
+        build_tree_kernel,
+        build_tree_kernel_efficient,
+        cublas_grouped_gemm,
+        custom_dispose,
+        custom_reduce,
+        fp8_blockwise_scaled_mm,
         fp8_scaled_mm,
         fused_add_rmsnorm,
         gelu_and_mul,
@@ -25,63 +28,32 @@ if torch.version.hip is not None:
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
         get_graph_buffer_ipc_meta,
-        init_custom_reduce,
+        get_meta_buffer_ipc_handle,
+        init_custom_ar,
         int8_scaled_mm,
         lightning_attention_decode,
-        meta_size,
         min_p_sampling_from_probs,
         moe_align_block_size,
-        register_buffer,
         register_graph_buffers,
         rmsnorm,
         sampling_scaling_penalties,
+        sgl_per_token_group_quant_fp8,
         silu_and_mul,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
         top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
     )
-
-    __all__ = [
-        "all_reduce_reg",
-        "all_reduce_unreg",
-        "allocate_meta_buffer",
-        "apply_rope_with_cos_sin_cache_inplace",
-        "bmm_fp8",
-        "dispose",
-        "fp8_scaled_mm",
-        "fused_add_rmsnorm",
-        "gelu_and_mul",
-        "gelu_tanh_and_mul",
-        "gemma_fused_add_rmsnorm",
-        "gemma_rmsnorm",
-        "get_graph_buffer_ipc_meta",
-        "get_meta_buffer_ipc_handle",
-        "init_custom_ar",
-        "int8_scaled_mm",
-        "lightning_attention_decode",
-        "meta_size",
-        "min_p_sampling_from_probs",
-        "moe_align_block_size",
-        "register_buffer",
-        "register_graph_buffers",
-        "rmsnorm",
-        "sampling_scaling_penalties",
-        "silu_and_mul",
-        "top_k_renorm_prob",
-        "top_k_top_p_sampling_from_probs",
-        "top_p_renorm_prob",
-    ]
 else:
+    assert torch.version.hip
     from sgl_kernel.ops import (
+        all_reduce_reg,
+        all_reduce_unreg,
+        allocate_meta_buffer,
         apply_rope_with_cos_sin_cache_inplace,
         bmm_fp8,
-        build_tree_kernel,
-        build_tree_kernel_efficient,
-        cublas_grouped_gemm,
-        custom_dispose,
-        custom_reduce,
-        fp8_blockwise_scaled_mm,
+        dispose,
         fp8_scaled_mm,
         fused_add_rmsnorm,
         gelu_and_mul,
@@ -89,23 +61,26 @@ else:
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
         get_graph_buffer_ipc_meta,
-        init_custom_reduce,
+        get_meta_buffer_ipc_handle,
+        init_custom_ar,
        int8_scaled_mm,
         lightning_attention_decode,
-        meta_size,
         min_p_sampling_from_probs,
         moe_align_block_size,
-        register_buffer,
         register_graph_buffers,
         rmsnorm,
         sampling_scaling_penalties,
+        sgl_per_token_group_quant_fp8,
         silu_and_mul,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
         top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
     )
 
     __all__ = [
+        "__version__",
         "apply_rope_with_cos_sin_cache_inplace",
         "bmm_fp8",
         "cublas_grouped_gemm",
@@ -135,4 +110,4 @@ else:
         "top_k_top_p_sampling_from_probs",
         "top_p_renorm_prob",
         "tree_speculative_sampling_target_only",
     ]
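The reshuffle flips the dispatch to test `torch.version.cuda` first and asserts `torch.version.hip` in the fallback, so an unexpected backend fails loudly instead of importing the wrong kernel set. A condensed sketch of the pattern (the op here is a stand-in, not the real import list):

# Condensed sketch of backend-conditional exports, following the new
# structure of sgl-kernel/__init__.py. `backend` stands in for the ops.
import torch

if torch.version.cuda:
    def backend() -> str:  # stand-in for the CUDA-only kernel imports
        return "cuda"
else:
    # Fail loudly when on neither CUDA nor ROCm, rather than silently
    # exporting kernels built for the wrong backend.
    assert torch.version.hip
    def backend() -> str:  # stand-in for the HIP-only kernel imports
        return "hip"

__all__ = ["backend"]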