# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to create common models
"""

from typing import Optional, Tuple, Union

import torch
from torch import nn

# Suffix -> divisor table for human-readable parameter counts.
# The empty suffix means "no scaling" (raw count is returned unchanged).
_SCALE_DIVISORS = {"B": 1e9, "M": 1e6, "K": 1e3, "": 1}


def get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[Union[int, float], str]:
    """Return the parameter count of ``model`` scaled to a human-readable unit.

    Args:
        model: Module whose parameters are counted (all parameters, including
            non-trainable ones, via ``model.parameters()``).
        scale: One of ``"auto"``, ``"B"``, ``"M"``, ``"K"`` or ``""``. With
            ``"auto"`` the largest unit strictly exceeded by the raw count is
            chosen (e.g. 2e9 params -> ``"B"``; exactly 1e9 -> ``"M"``).

    Returns:
        A ``(n_params, scale)`` pair where ``n_params`` is the (possibly
        fractional) count in the chosen unit. When the resolved scale is
        ``""`` the raw integer count is returned unscaled.

    Raises:
        NotImplementedError: If ``scale`` is not a recognized unit.
    """
    n_params = sum(p.numel() for p in model.parameters())

    if scale == "auto":
        # Strict ">" thresholds: a count of exactly 1e9 is reported in "M", etc.
        if n_params > 1e9:
            scale = "B"
        elif n_params > 1e6:
            scale = "M"
        elif n_params > 1e3:
            scale = "K"
        else:
            scale = ""

    if scale not in _SCALE_DIVISORS:
        raise NotImplementedError(f"Unknown scale {scale}")
    divisor = _SCALE_DIVISORS[scale]
    if divisor != 1:
        n_params = n_params / divisor

    return n_params, scale


def print_model_size(model: nn.Module, name: Optional[str] = None):
    """Print the model's parameter count in an auto-selected unit.

    Args:
        model: Module whose size is reported.
        name: Label to print; defaults to the model's class name.
    """
    n_params, scale = get_model_size(model, scale="auto")
    if name is None:
        name = model.__class__.__name__
    print(f"{name} contains {n_params:.2f}{scale} parameters")


def compute_position_id_with_mask(mask):
    """Derive position ids from an attention mask.

    Positions count only unmasked (nonzero) entries along the last dimension:
    ``cumsum(mask) - 1`` clipped at 0, so any leading padding positions are
    pinned to 0 rather than going negative.
    """
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)