Unverified commit 4eb4b401, authored by Yineng Zhang, committed by GitHub

update and simplify CustomOp (#3249)

parent 17dbf976
The commit adds a new module, sglang.srt.custom_op, containing a minimal platform dispatcher that the layers below switch to:

import torch
from torch import nn

# Platform flags, resolved once at import time.
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip


class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        # Bind the backend-specific forward once, at construction time.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        # Pure-PyTorch reference path; subclasses are expected to provide it.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        raise NotImplementedError

    def forward_xpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
        elif _is_rocm:
            return self.forward_hip
        else:
            return self.forward_native
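For orientation, here is a minimal sketch (not part of the commit) of how a layer uses this base class: subclasses override the backend paths they care about, and dispatch_forward binds one of them at construction time. The ToyScale op below is hypothetical.

import torch

class ToyScale(CustomOp):
    """Doubles its input; exists only to illustrate dispatch."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path (CPU / fallback).
        return x * 2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused kernel here; this sketch just reuses
        # the reference path.
        return self.forward_native(x)


op = ToyScale()
y = op(torch.ones(4))  # forward_cuda on NVIDIA GPUs, forward_native on CPU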
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 logger = logging.getLogger(__name__)
-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
         return out
-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
......
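To make the dispatch concrete, the following is a hedged sketch of what an activation op like SiluAndMul looks like under the new base class; the forward_cuda body here simply reuses the reference path, whereas the real file (elided in the hunk above) calls the fused silu_and_mul kernel from sgl_kernel.

import torch
import torch.nn.functional as F

class SiluAndMulSketch(CustomOp):
    """Computes silu(x[..., :d]) * x[..., d:], where d is half the last dim."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder: the production path dispatches to the fused sgl_kernel
        # silu_and_mul kernel rather than the eager reference implementation.
        return self.forward_native(x)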
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Note: this module (sglang.srt.layers.custom_op_util) is retired by the commit;
# every hunk below drops its register_custom_op import.
from vllm.model_executor.custom_op import CustomOp


def register_custom_op(op_name):
    # Register the decorated class with vLLM's CustomOp registry when the
    # installed vLLM version exposes CustomOp.register; otherwise return the
    # class unchanged, making the decorator a no-op.
    def decorator(cls):
        if hasattr(CustomOp, "register"):
            return CustomOp.register(op_name)(cls)
        else:
            return cls

    return decorator
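A hedged before/after sketch of how ops were declared, using the sglang_rmsnorm name that appears in the layernorm hunk below; the class names here are illustrative and the bodies are elided.

# Before: CustomOp came from vLLM and each op registered itself by name.
@register_custom_op("sglang_rmsnorm")
class RMSNormBefore(CustomOp):  # vllm.model_executor.custom_op.CustomOp
    ...


# After: CustomOp comes from sglang.srt.custom_op and no registration is needed.
class RMSNormAfter(CustomOp):  # sglang.srt.custom_op.CustomOp
    ...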
@@ -29,14 +29,11 @@ if is_cuda_available():
         rmsnorm,
     )
-from vllm.model_executor.custom_op import CustomOp
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 logger = logging.getLogger(__name__)
-@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
         return x, residual
-@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
......
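The commit does not touch the math of these layers; for readers skimming the diff, here is a hedged reference sketch of what RMSNorm's native path computes (the CUDA path is assumed to call the fused rmsnorm kernel imported in the hunk above).

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight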
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
         param_data[expert_id] = loaded_weight
-@register_custom_op("sglang_unquantized_ep_moe")
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def create_weights(
         self,
......
@@ -5,14 +5,13 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 import torch
-from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -67,7 +66,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
-@register_custom_op("sglang_unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
......
@@ -7,9 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 _is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
         return torch.stack((o1, o2), dim=-1).flatten(-2)
-@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
......
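The _apply_rotary_emb helper named in the second hunk header is unchanged by this commit; below is a hedged sketch of the interleaved rotation consistent with the return line shown above (the function name is illustrative and details such as dtype casts and the neox-style branch are omitted).

import torch

def apply_rotary_emb_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Even/odd positions of each rotary pair along the last dimension.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    # Re-interleave the rotated halves, matching the context line above.
    return torch.stack((o1, o2), dim=-1).flatten(-2)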
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable
 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
......