from .palm_small import model

# Start from the palm_small config and override only the size
# hyper-parameters below. Every other field is left as palm_small set it.
# NOTE(review): these values appear to match the 8B row of the PaLM paper's
# model table (32 layers, 16 heads, d_model 4096). Confirm that is the
# intended variant.
# NOTE(review): `model` is the object imported from palm_small, so these
# assignments mutate that object rather than a copy. Verify that the config
# loader gives each config file its own instance.

# Presumably the model width and layer count (going by the field names).
model.cfg.dim = 4096
model.cfg.depth = 32

# Attention layout: num_heads * dim_head = 16 * 256 = 4096, equal to dim above.
model.cfg.dim_head = 256
model.cfg.num_heads = 16