import torch
import torch.nn as nn
import torch.nn.functional as F


def grab_first_if_tuple(x):
    """If the input is a tuple, return its first element; otherwise return it unchanged."""
    return x[0] if isinstance(x, tuple) else x


class ParallelGatedMLP(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()

        self.layer_idx = layer_idx
        multiple_of = config.get("inner_size_multiple_of", 64)

        self.act_type = config.get("mlp_activation", "gelu")
        if self.act_type == "gelu":
            self.act = F.gelu
        elif self.act_type == "silu":
            self.act = F.silu
        else:
            raise NotImplementedError(f"Unsupported activation: {self.act_type}")

        # Evo2-style models disable the MLP activation in all but the first layer.
        if self.layer_idx > 0 and config.get("evo2_style_activations", False):
            self.act = nn.Identity()

        # Inner width is hardcoded here; `multiple_of` would normally round it up
        # (11264 is already a multiple of 64, so rounding is a no-op).
        inner_size = 11264

        self.l1 = nn.Linear(
            in_features=config.get("hidden_size", 4096),
            out_features=inner_size,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=config.get("hidden_size", 4096),
            out_features=inner_size,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=inner_size,
            out_features=config.get("hidden_size", 4096),
            bias=False,
        )

        # Ensure the weights are contiguous (nn.Linear usually produces
        # contiguous weights by default; this is a safeguard).
        self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
        self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
        self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())

    def forward(self, z):
        # Gated MLP: project twice, gate the activated branch elementwise,
        # then project back down to the hidden size.
        z1, z2 = self.l1(z), self.l2(z)
        z1, z2 = grab_first_if_tuple(z1), grab_first_if_tuple(z2)
        y = self.l3(self.act(z1) * z2)
        return grab_first_if_tuple(y)


# === Example usage ===
if __name__ == "__main__":
    # Mock configuration
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
    layer_idx = 0

    # Pick the device first: bf16 runs fastest on a bf16-capable GPU
    # (e.g., A100, H100); fall back to CPU so the script still runs.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Instantiate the model and cast it to bfloat16 on the target device
    model = ParallelGatedMLP(config, layer_idx)
    model = model.to(dtype=torch.bfloat16, device=device)

    # Input tensor (batch=1, seq_len=1, hidden=4096)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)

    # Inference
    with torch.no_grad():
        for _ in range(10):
            output = model(x)
    print("output:", output.shape, output.dtype)
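
    # For reference, the forward pass above is a SwiGLU-style gated MLP,
    # y = l3(act(l1(z)) * l2(z)). Below is a minimal sanity-check sketch
    # (not part of the original script; `reference` is a name introduced
    # here purely for this check). Since the same deterministic ops are
    # replayed on the same input, the results should match exactly.
    with torch.no_grad():
        reference = model.l3(model.act(model.l1(x)) * model.l2(x))
    assert torch.allclose(output, reference)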