Commit 65e1d22d authored by zhuwenwen's avatar zhuwenwen
Browse files

fix moe run error

parent 661623f0
......@@ -25,4 +25,105 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
\ No newline at end of file
self.moe_quant_config: FusedMoEQuantConfig | None = None
@abstractmethod
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
def uses_weight_scale_2_pattern(self) -> bool:
"""
Returns True if this quantization method uses 'weight_scale_2' pattern
for per-tensor weight scales (e.g., FP4 variants), False otherwise.
This method should be overridden by subclasses that use the
'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
"""
return False
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
from .all2all_utils import maybe_make_prepare_finalize
return maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize"
)
def prepare_dp_allgather_tensor(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise NotImplementedError(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f"{self.__class__.__name__}."
)
@abstractmethod
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
raise NotImplementedError
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return None
@property
def supports_eplb(self) -> bool:
return False
@property
def allow_inplace(self) -> bool:
return False
@property
def method_name(self) -> str:
return self.__class__.__name__
@property
def is_monolithic(self) -> bool:
return False
# @abstractmethod
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
# @abstractmethod
def apply_monolithic(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
\ No newline at end of file
......@@ -1069,7 +1069,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
else:
# Overlap shared expert compute with all2all dispatch.
......@@ -1082,7 +1081,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
# TODO(lucas): refactor this in the alternative schedules followup
......@@ -1361,4 +1359,4 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
)
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment