Unverified commit 4eb4b401, authored by Yineng Zhang, committed by GitHub

update and simplify CustomOp (#3249)

parent 17dbf976
New module sglang.srt.custom_op:

import torch
from torch import nn

# Backend detection happens once at import time.
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip


class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        # Bind the backend-specific forward once at construction time.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        raise NotImplementedError

    def forward_xpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
        elif _is_rocm:
            return self.forward_hip
        else:
            return self.forward_native
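The base class binds the backend-specific implementation once, in __init__: dispatch_forward checks the detected platform and returns the matching forward_* method, so forward itself does no per-call branching. A minimal sketch of a subclass, assuming the module path above (the ScaleOp name and its bodies are purely illustrative):

import torch

from sglang.srt.custom_op import CustomOp


class ScaleOp(CustomOp):
    """Hypothetical op used only to illustrate the dispatch contract."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path; also what the XPU/HPU/CPU hooks call.
        return x * 2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused kernel here; the reference math keeps
        # this sketch self-contained.
        return x * 2


op = ScaleOp()
out = op(torch.ones(4))  # forward_cuda on CUDA builds; CPU-only builds dispatch to forward_native

The hunks below migrate each existing op onto this in-repo base class and drop the old registration decorator.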
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
         return out
 
 
-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
...
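The activation hunks only swap the CustomOp import and delete the decorator; the op bodies are unchanged. For orientation, a gated activation under the new base class looks roughly like the sketch below (the class name and the forward_cuda body are illustrative; the real SiluAndMul calls the fused silu_and_mul kernel from sgl_kernel on CUDA):

import torch
import torch.nn.functional as F

from sglang.srt.custom_op import CustomOp


class SiluAndMulSketch(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # The last dimension holds [gate, up]; return SiLU(gate) * up.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the fused kernel call, kept as the reference math so
        # the sketch runs anywhere.
        return self.forward_native(x)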
Removed module sglang.srt.layers.custom_op_util (its register_custom_op helper is no longer used anywhere):

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from vllm.model_executor.custom_op import CustomOp


def register_custom_op(op_name):
    def decorator(cls):
        # Register with vLLM's CustomOp registry when that vLLM version
        # exposes a register hook; otherwise the decorator is a no-op.
        if hasattr(CustomOp, "register"):
            return CustomOp.register(op_name)(cls)
        else:
            return cls

    return decorator
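Before this commit every op was declared through the shim above, for example:

@register_custom_op("sglang_silu_and_mul")
class SiluAndMul(CustomOp):
    ...

With CustomOp now living in-repo there is no external registry to guard against, so the shim and all of its call sites are dropped.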
@@ -29,14 +29,11 @@ if is_cuda_available():
         rmsnorm,
     )
 
-from vllm.model_executor.custom_op import CustomOp
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
         return x, residual
 
 
-@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
...
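RMSNorm and GemmaRMSNorm keep their fused rmsnorm kernels and lose only the decorator. As a reminder of what the forward_native path of such a norm computes, here is a generic RMSNorm sketch on the new base class (hypothetical class, simplified: the real layernorm.py also handles a fused residual path):

import torch
from torch import nn

from sglang.srt.custom_op import CustomOp


class RMSNormSketch(CustomOp):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square of the features, then scale.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return x * self.weight

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the fused rmsnorm kernel from sgl_kernel.
        return self.forward_native(x)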
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
             param_data[expert_id] = loaded_weight
 
 
-@register_custom_op("sglang_unquantized_ep_moe")
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def create_weights(
         self,
...
@@ -5,14 +5,13 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -67,7 +66,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
 
 
-@register_custom_op("sglang_unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
...
@@ -7,9 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
...
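The rotary embedding hunk is shown only up to the tail of _apply_rotary_emb. The rotation it applies pairs up feature dimensions and rotates each (x1, x2) pair by a position-dependent angle; a simplified, self-contained sketch of that math (not the exact signature used in the file) is:

import torch


def apply_rotary_emb_sketch(
    x: torch.Tensor,    # [..., num_pairs, 2]: adjacent feature dims paired up
    cos: torch.Tensor,  # [..., num_pairs]
    sin: torch.Tensor,  # [..., num_pairs]
) -> torch.Tensor:
    x1, x2 = x.unbind(dim=-1)
    # Standard 2-D rotation per pair.
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    # Interleave the rotated components back into the last dimension,
    # matching the torch.stack(...).flatten(-2) visible in the hunk above.
    return torch.stack((o1, o2), dim=-1).flatten(-2)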
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable
 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
...