Commit 5c004388 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix run error

parent fd46df33
......@@ -65,9 +65,9 @@ class FixFunctionalizationPass(VllmInductorPass):
# elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
# mutated_args = {1: 'result', 2: 'residual'}
# self.defunctionalize(graph, node, mutated_args)
# elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
# mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
# self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
# torch.ops._C.rms_norm_static_fp8_quant.default,
......
......@@ -82,9 +82,9 @@ class QuantKey(NamedTuple):
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym:
......
......@@ -686,7 +686,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = min(sorted_token_ids.size(0),
A.size(0) * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.size(1) if not use_nn_moe else B.size[2], META['BLOCK_SIZE_N']), )
B.size(1) if not use_nn_moe else B.size(2), META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
......@@ -813,7 +813,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1) if not use_nn_moe else B.size[2],
B.size(1) if not use_nn_moe else B.size(2),
A.size(1),
EM,
num_tokens,
......@@ -1485,7 +1485,7 @@ def fused_experts_impl(
assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch")
elif use_nn_moe:
assert hidden_states.size[1] == w1.size[1], "Hidden size mismatch"
assert hidden_states.size(1) == w1.size(1), "Hidden size mismatch"
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
......@@ -1500,7 +1500,7 @@ def fused_experts_impl(
num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
......
......@@ -58,11 +58,10 @@ else:
FusedMoEPermuteExpertsUnpermute = None # type: ignore
FusedMoEPrepareAndFinalize = None # type: ignore
if is_rocm_aiter_moe_enabled():
# if is_rocm_aiter_moe_enabled():
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
# rocm_aiter_grouped_topk as grouped_topk)
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
elif current_platform.is_cpu():
if current_platform.is_cpu():
pass
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment