Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b509db58
Unverified
Commit
b509db58
authored
Nov 24, 2024
by
Yineng Zhang
Committed by
GitHub
Nov 24, 2024
Browse files
feat: remove the dependency on FusedMoE (#2153)
parent
dbe17293
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1602 additions
and
7 deletions
+1602
-7
python/sglang/srt/layers/quantization/__init__.py
python/sglang/srt/layers/quantization/__init__.py
+15
-5
python/sglang/srt/layers/triton_fused_moe/__init__.py
python/sglang/srt/layers/triton_fused_moe/__init__.py
+44
-0
python/sglang/srt/layers/triton_fused_moe/configs/README
python/sglang/srt/layers/triton_fused_moe/configs/README
+10
-0
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
+858
-0
python/sglang/srt/layers/triton_fused_moe/layer.py
python/sglang/srt/layers/triton_fused_moe/layer.py
+631
-0
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+1
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+43
-1
No files found.
python/sglang/srt/layers/quantization/__init__.py
View file @
b509db58
...
...
@@ -57,12 +57,23 @@ __all__ = [
"QUANTIZATION_METHODS"
,
]
"""
def fp8_get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
def
fp8_get_quant_method
(
self
,
layer
,
prefix
):
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
,
)
from
sglang.srt.layers.triton_fused_moe.layer
import
FusedMoE
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
ignored_layers
):
from
sglang.srt.layers.linear
import
UnquantizedLinearMethod
return
UnquantizedLinearMethod
()
return
Fp8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
...
...
@@ -71,4 +82,3 @@ def fp8_get_quant_method(
setattr
(
Fp8Config
,
"get_quant_method"
,
fp8_get_quant_method
)
"""
python/sglang/srt/layers/triton_fused_moe/__init__.py
0 → 100644
View file @
b509db58
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Optional
import
sglang.srt.layers.triton_fused_moe.fused_moe
# noqa
from
sglang.srt.layers.triton_fused_moe.fused_moe
import
(
fused_experts
,
fused_topk
,
get_config_file_name
,
grouped_topk
,
)
from
sglang.srt.layers.triton_fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
)
# Module-level override for the fused-MoE kernel configuration.
# None means "no override active"; see override_config/get_config below.
_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
    """Temporarily install *config* as the global fused-MoE kernel config.

    Args:
        config: Mapping used in place of the module-level config for the
            duration of the ``with`` block.

    Yields:
        None. The previous config is restored on exit.
    """
    global _config
    old_config = _config
    _config = config
    try:
        yield
    finally:
        # Restore even when the body raises; without this, an exception
        # inside the with-block leaked the override into all later calls.
        _config = old_config


def get_config() -> Optional[Dict[str, Any]]:
    """Return the currently active override config, or None if unset."""
    return _config
# Public API of the triton_fused_moe package. Note that "fused_moe" refers
# to the fused_moe submodule, which is bound as a package attribute by the
# `import sglang.srt.layers.triton_fused_moe.fused_moe` statement above.
__all__ = [
    "FusedMoE",
    "FusedMoEMethodBase",
    "FusedMoeWeightScaleSupported",
    "override_config",
    "get_config",
    "fused_moe",
    "fused_topk",
    "fused_experts",
    "get_config_file_name",
    "grouped_topk",
]
python/sglang/srt/layers/triton_fused_moe/configs/README
0 → 100644
View file @
b509db58
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
0 → 100644
View file @
b509db58
This diff is collapsed.
Click to expand it.
python/sglang/srt/layers/triton_fused_moe/layer.py
0 → 100644
View file @
b509db58
This diff is collapsed.
Click to expand it.
python/sglang/srt/models/deepseek_v2.py
View file @
b509db58
...
...
@@ -27,7 +27,6 @@ from vllm.distributed import (
get_tp_group
,
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -42,6 +41,7 @@ from sglang.srt.layers.linear import (
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.triton_fused_moe
import
FusedMoE
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
...
...
python/sglang/srt/utils.py
View file @
b509db58
...
...
@@ -31,7 +31,7 @@ import time
import
warnings
from
importlib.metadata
import
PackageNotFoundError
,
version
from
io
import
BytesIO
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Protocol
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Protocol
,
Tuple
,
Union
import
numpy
as
np
import
psutil
...
...
@@ -45,6 +45,7 @@ from packaging import version as pkg_version
from
starlette.routing
import
Mount
from
torch
import
nn
from
torch.func
import
functional_call
from
torch.library
import
Library
from
torch.profiler
import
ProfilerActivity
,
profile
,
record_function
from
triton.runtime.cache
import
(
FileCacheManager
,
...
...
@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
def crash_on_warnings():
    """Whether warnings should be escalated into crashes (CI runs only)."""
    # CI pipelines export SGLANG_IS_IN_CI=true; any other value (or the
    # absence of the variable) keeps warnings non-fatal.
    flag = os.getenv("SGLANG_IS_IN_CI", "false")
    return flag == "true"
def get_device_name(device_id: int = 0) -> str:
    """Return the name of the first available accelerator device.

    Backends are probed in priority order (cuda, hip, xpu, hpu); the first
    one that exists on this torch build and reports available is queried.

    Args:
        device_id: Index of the device to query. Defaults to 0.

    Returns:
        The device name string reported by the selected backend.

    Raises:
        RuntimeError: If no supported accelerator backend is available.
            (The original fell through and implicitly returned None here,
            silently violating the declared ``-> str`` return type.)
    """
    for backend in ("cuda", "hip", "xpu", "hpu"):
        # Equivalent to hasattr(torch, backend) + torch.<backend>.is_available()
        module = getattr(torch, backend, None)
        if module is not None and module.is_available():
            return module.get_device_name(device_id)
    raise RuntimeError("No available device found among: cuda, hip, xpu, hpu")
sglang_lib
=
Library
(
"sglang"
,
"FRAGMENT"
)
# noqa
def
direct_register_custom_op
(
op_name
:
str
,
op_func
:
Callable
,
mutates_args
:
List
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
):
import
torch.library
if
hasattr
(
torch
.
library
,
"infer_schema"
):
schema_str
=
torch
.
library
.
infer_schema
(
op_func
,
mutates_args
=
mutates_args
)
else
:
# for pytorch 2.4
import
torch._custom_op.impl
schema_str
=
torch
.
_custom_op
.
impl
.
infer_schema
(
op_func
,
mutates_args
)
my_lib
=
target_lib
or
sglang_lib
my_lib
.
define
(
op_name
+
schema_str
)
my_lib
.
impl
(
op_name
,
op_func
,
"CUDA"
)
if
fake_impl
is
not
None
:
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment