Commit 92c6171e authored by zhuwenwen's avatar zhuwenwen
Browse files

update op

parent a857453f
...@@ -43,8 +43,6 @@ from vllm.utils import direct_register_custom_op ...@@ -43,8 +43,6 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
if envs.VLLM_USE_LIGHTOP:
from lightop import op
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0') os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
...@@ -1771,6 +1769,7 @@ def fused_experts_impl( ...@@ -1771,6 +1769,7 @@ def fused_experts_impl(
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick: if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick:
from lightop import op as op
op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], None, routed_scaling_factor) out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], None, routed_scaling_factor)
# else: # else:
......
...@@ -43,8 +43,6 @@ from vllm.platforms.interface import CpuArchEnum ...@@ -43,8 +43,6 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts
...@@ -1287,6 +1285,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1287,6 +1285,7 @@ class FusedMoE(torch.nn.Module):
assert num_expert_group is not None assert num_expert_group is not None
if use_fused_gate: if use_fused_gate:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
topk_weights, topk_ids = op.moe_fused_gate( topk_weights, topk_ids = op.moe_fused_gate(
router_logits, router_logits,
e_score_correction_bias, e_score_correction_bias,
......
...@@ -9,8 +9,6 @@ from vllm.triton_utils import tl, triton ...@@ -9,8 +9,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, round_up from vllm.utils import cdiv, round_up
import vllm.envs as envs import vllm.envs as envs
if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
@triton.jit @triton.jit
...@@ -234,6 +232,7 @@ def moe_align_block_size( ...@@ -234,6 +232,7 @@ def moe_align_block_size(
device=topk_ids.device) device=topk_ids.device)
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None, None, None) expert_ids, num_tokens_post_pad, None, None, None)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment