from dataclasses import dataclass
from typing import Callable, Optional
from functools import partial

@dataclass
class SubmoduleCallables:
    """
    Holds references to the forward and dw (weight-grad) callables
    for a particular submodule.

    Each callable defaults to a placeholder that raises
    ``NotImplementedError`` when invoked, so a submodule whose callable
    was never assigned fails with a clear message at call time.

    NOTE(review): earlier documentation of this class also listed a
    ``dgrad`` callable, but no such field is defined here -- confirm
    whether one is expected.

    Attributes:
        forward: Callable for the submodule's forward step; placeholder
            by default.
        dw: Callable for the submodule's weight-grad step; placeholder
            by default.
        is_moe: Boolean flag, ``False`` by default.
        is_deepep: Boolean flag, ``False`` by default.
    """

    # Plain function (no ``self``): it is only used while the class body
    # executes, to build the default placeholders below via ``partial``.
    def raise_not_implemented(name: str, *args, **kwargs):
        """Raise ``NotImplementedError`` naming the missing callable.

        Extra positional/keyword arguments are accepted and ignored so
        that invoking a placeholder with real call arguments still
        reports ``NotImplementedError`` instead of a ``TypeError`` about
        the argument count.

        Args:
            name: Name of the callable that has not been provided.
            *args: Ignored.
            **kwargs: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(f"{name} not implemented.")

    forward: Optional[Callable] = partial(raise_not_implemented, "forward")
    dw: Optional[Callable] = partial(raise_not_implemented, "dw")
    is_moe: bool = False
    is_deepep: bool = False


@dataclass
class TransformerLayerSubmoduleCallables:
    """
    Groups the SubmoduleCallables of one transformer layer, one entry
    per submodule: 'attention', 'dispatch', 'mlp', 'combine'.

    On construction, the layer-level ``is_moe`` and ``is_deepep`` flags
    are written onto every contained submodule entry, overwriting
    whatever values those entries carried before.
    """

    attention: SubmoduleCallables
    dispatch: SubmoduleCallables
    mlp: SubmoduleCallables
    combine: SubmoduleCallables
    is_moe: bool = False
    is_deepep: bool = False

    def as_array(self):
        """Return the four submodule entries as a new list, in the order
        attention, dispatch, mlp, combine."""
        ordered_names = ("attention", "dispatch", "mlp", "combine")
        return [getattr(self, field_name) for field_name in ordered_names]

    def __post_init__(self):
        """Copy this layer's ``is_moe`` / ``is_deepep`` flags onto each
        submodule entry (mutates the entries in place)."""
        entries = self.as_array()
        moe_flag = self.is_moe
        deepep_flag = self.is_deepep
        for entry in entries:
            entry.is_moe = moe_flag
            entry.is_deepep = deepep_flag
