Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e3938b2f
Unverified
Commit
e3938b2f
authored
Nov 24, 2024
by
Yineng Zhang
Committed by
GitHub
Nov 24, 2024
Browse files
feat: update other MoE models deps (#2156)
parent
c211e7b6
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
28 additions
and
14 deletions
+28
-14
python/sglang/srt/layers/fused_moe/layer.py
python/sglang/srt/layers/fused_moe/layer.py
+1
-6
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
+2
-0
python/sglang/srt/layers/triton_fused_moe/layer.py
python/sglang/srt/layers/triton_fused_moe/layer.py
+4
-2
python/sglang/srt/models/dbrx.py
python/sglang/srt/models/dbrx.py
+1
-1
python/sglang/srt/models/deepseek.py
python/sglang/srt/models/deepseek.py
+1
-1
python/sglang/srt/models/mixtral.py
python/sglang/srt/models/mixtral.py
+1
-1
python/sglang/srt/models/olmoe.py
python/sglang/srt/models/olmoe.py
+1
-1
python/sglang/srt/models/qwen2_moe.py
python/sglang/srt/models/qwen2_moe.py
+1
-1
python/sglang/srt/models/xverse_moe.py
python/sglang/srt/models/xverse_moe.py
+1
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+15
-0
No files found.
python/sglang/srt/layers/fused_moe/layer.py
View file @
e3938b2f
...
@@ -153,12 +153,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -153,12 +153,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
],
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
raise
NotImplementedError
(
"The TPU backend currently does not support MoE."
)
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
)
class
FusedMoE
(
torch
.
nn
.
Module
):
class
FusedMoE
(
torch
.
nn
.
Module
):
...
...
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
View file @
e3938b2f
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
"""Fused MoE kernel."""
"""Fused MoE kernel."""
import
functools
import
functools
...
...
python/sglang/srt/layers/triton_fused_moe/layer.py
View file @
e3938b2f
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Callable
,
List
,
Optional
,
Tuple
...
@@ -18,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
...
@@ -18,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
from
sglang.srt.utils
import
set_weight_attrs
from
sglang.srt.utils
import
set_weight_attrs
if
torch
.
cuda
.
is_available
()
or
torch
.
hip
.
is_available
():
if
torch
.
cuda
.
is_available
()
or
torch
.
hip
.
is_available
():
from
.fused_moe
import
fused_experts
from
sglang.srt.layers.triton_fused_moe
.fused_moe
import
fused_experts
else
:
else
:
fused_experts
=
None
# type: ignore
fused_experts
=
None
# type: ignore
...
@@ -512,7 +514,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -512,7 +514,7 @@ class FusedMoE(torch.nn.Module):
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
):
):
from
vllm.model_executor
.layers.fused_moe.fused_moe
import
(
from
sglang.srt
.layers.
triton_
fused_moe.fused_moe
import
(
fused_topk
,
fused_topk
,
grouped_topk
,
grouped_topk
,
)
)
...
...
python/sglang/srt/models/dbrx.py
View file @
e3938b2f
...
@@ -24,7 +24,6 @@ from vllm.distributed import (
...
@@ -24,7 +24,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
...
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
...
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.triton_fused_moe
import
fused_moe
from
sglang.srt.layers.vocab_parallel_embedding
import
(
from
sglang.srt.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
ParallelLMHead
,
...
...
python/sglang/srt/models/deepseek.py
View file @
e3938b2f
...
@@ -26,7 +26,6 @@ from vllm.distributed import (
...
@@ -26,7 +26,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -41,6 +40,7 @@ from sglang.srt.layers.linear import (
...
@@ -41,6 +40,7 @@ from sglang.srt.layers.linear import (
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.triton_fused_moe
import
fused_moe
from
sglang.srt.layers.vocab_parallel_embedding
import
(
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
...
...
python/sglang/srt/models/mixtral.py
View file @
e3938b2f
...
@@ -22,7 +22,6 @@ import torch
...
@@ -22,7 +22,6 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
...
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.torchao_utils
import
apply_torchao_config_
from
sglang.srt.layers.torchao_utils
import
apply_torchao_config_
from
sglang.srt.layers.triton_fused_moe
import
FusedMoE
from
sglang.srt.layers.vocab_parallel_embedding
import
(
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
...
...
python/sglang/srt/models/olmoe.py
View file @
e3938b2f
...
@@ -27,7 +27,6 @@ from vllm.distributed import (
...
@@ -27,7 +27,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -43,6 +42,7 @@ from sglang.srt.layers.layernorm import RMSNorm
...
@@ -43,6 +42,7 @@ from sglang.srt.layers.layernorm import RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.triton_fused_moe
import
FusedMoE
from
sglang.srt.layers.vocab_parallel_embedding
import
(
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
...
...
python/sglang/srt/models/qwen2_moe.py
View file @
e3938b2f
...
@@ -26,7 +26,6 @@ from vllm.distributed import (
...
@@ -26,7 +26,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -42,6 +41,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
...
@@ -42,6 +41,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.torchao_utils
import
apply_torchao_config_
from
sglang.srt.layers.torchao_utils
import
apply_torchao_config_
from
sglang.srt.layers.triton_fused_moe
import
FusedMoE
from
sglang.srt.layers.vocab_parallel_embedding
import
(
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
...
...
python/sglang/srt/models/xverse_moe.py
View file @
e3938b2f
...
@@ -24,7 +24,6 @@ from vllm.distributed import (
...
@@ -24,7 +24,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -38,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -38,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.triton_fused_moe
import
fused_moe
from
sglang.srt.layers.vocab_parallel_embedding
import
(
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
...
...
python/sglang/srt/utils.py
View file @
e3938b2f
...
@@ -957,6 +957,21 @@ def direct_register_custom_op(
...
@@ -957,6 +957,21 @@ def direct_register_custom_op(
fake_impl
:
Optional
[
Callable
]
=
None
,
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
):
):
"""
`torch.library.custom_op` can have significant overhead because it
needs to consider complicated dispatching logic. This function
directly registers a custom op and dispatches it to the CUDA backend.
See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
for more details.
By default, the custom op is registered to the vLLM library. If you
want to register it to a different library, you can pass the library
object to the `target_lib` argument.
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
"""
import
torch.library
import
torch.library
if
hasattr
(
torch
.
library
,
"infer_schema"
):
if
hasattr
(
torch
.
library
,
"infer_schema"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment