# Copyright (c) DP Technology. # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import contextlib import importlib import logging import os import sys import warnings from functools import partial from typing import List, Callable, Any, Dict try: import unicore_fused_multi_tensor HAS_MULTI_TENSOR = True except: print("please install latest fused_ops to get multi_tensor.") HAS_MULTI_TENSOR = False try: import unicore_fused_rounding HAS_FUSED_ROUNDING = True except: print("please install latest fused_ops to get fused_rounding.") HAS_FUSED_ROUNDING = False import torch import torch.nn.functional as F logger = logging.getLogger(__name__) def apply_to_sample(f, sample): if hasattr(sample, "__len__") and len(sample) == 0: return {} def _apply(x): if torch.is_tensor(x): return f(x) elif isinstance(x, dict): return {key: _apply(value) for key, value in x.items()} elif isinstance(x, list): return [_apply(x) for x in x] elif isinstance(x, tuple): return tuple(_apply(x) for x in x) elif isinstance(x, set): return {_apply(x) for x in x} else: return x return _apply(sample) def move_to_cuda(sample, device=None): device = device or torch.cuda.current_device() def _move_to_cuda(tensor): # non_blocking is ignored if tensor is not pinned, so we can always set # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) return tensor.to(device=device, non_blocking=True) return apply_to_sample(_move_to_cuda, sample) def move_to_cpu(sample): def _move_to_cpu(tensor): # PyTorch has poor support for half tensors (float16) on CPU. # Move any such tensors to float32. if tensor.dtype in {torch.bfloat16, torch.float16}: tensor = tensor.to(dtype=torch.float32) return tensor.cpu() return apply_to_sample(_move_to_cpu, sample) def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: per_device_grads = {} norms = [] for grad in grads: device = grad.device dtype = grad.dtype if device not in per_device_grads: per_device_grads[device] = {} if dtype not in per_device_grads[device]: per_device_grads[device][dtype] = [] per_device_grads[device][dtype].append(grad) for device in per_device_grads.keys(): for dtype in per_device_grads[device].keys(): cur_grads = per_device_grads[device][dtype] if HAS_MULTI_TENSOR and device.type == "cuda": norm = unicore_fused_multi_tensor.l2norm( chunk_size, [cur_grads] ) norms.append(norm) else: norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_grads] total_norm = torch.norm(torch.stack(norms), p=2, dtype=torch.float32) return total_norm @torch.no_grad() def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: if isinstance(params, torch.Tensor): params = [params] params = list(params) grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)] if len(grads) == 0: if len(params) > 0: return params[0].new_tensor(0.0) else: return torch.tensor(0.0) if len(grads) == 1: total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) else: total_norm = multi_tensor_total_norm(grads) if aggregate_norm_fn is not None: total_norm = aggregate_norm_fn(total_norm) if max_norm > 0: max_norm = float(max_norm) clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) for g in grads: g.mul_(clip_coef) return total_norm def import_user_module(args): module_path = getattr(args, "user_dir", None) if module_path is not None: module_path = os.path.abspath(args.user_dir) if not os.path.exists(module_path) and not os.path.isfile(os.path.dirname(module_path)): unicore_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) if os.path.exists(unicore_rel_path): module_path = unicore_rel_path else: unicore_rel_path = os.path.join( os.path.dirname(__file__), "..", args.user_dir ) if os.path.exists(unicore_rel_path): module_path = unicore_rel_path else: raise FileNotFoundError(module_path) # ensure that user modules are only imported once import_user_module.memo = getattr(import_user_module, "memo", set()) if module_path not in import_user_module.memo: import_user_module.memo.add(module_path) module_parent, module_name = os.path.split(module_path) if module_name not in sys.modules: sys.path.insert(0, module_parent) importlib.import_module(module_name) else: raise ImportError( "Failed to import --user-dir={} because the corresponding module name " "({}) is not globally unique. Please rename the directory to " "something unique and try again.".format(module_path, module_name) ) def get_activation_fn(activation: str) -> Callable: """ Returns the activation function corresponding to `activation` """ if activation == "relu": return F.relu elif activation == "gelu": return F.gelu elif activation == "tanh": return torch.tanh elif activation == "linear": return lambda x: x else: raise RuntimeError("--activation-fn {} not supported".format(activation)) def get_available_activation_fns() -> List: return [ "relu", "gelu", "tanh", "linear", ] def has_parameters(module): try: next(module.parameters()) return True except StopIteration: return False def get_rng_state(): state = {"torch_rng_state": torch.get_rng_state()} if torch.cuda.is_available(): state["cuda_rng_state"] = torch.cuda.get_rng_state() return state def set_rng_state(state): torch.set_rng_state(state["torch_rng_state"]) if torch.cuda.is_available(): torch.cuda.set_rng_state(state["cuda_rng_state"]) @contextlib.contextmanager def torch_seed(seed, *addl_seeds): """Context manager which seeds the NumPy PRNG with the specified seed and restores the state afterward""" if seed is None: yield return def check_seed(s): assert type(s) == int or type(s) == np.int32 or type(s) == np.int64 check_seed(seed) if len(addl_seeds) > 0: for s in addl_seeds: check_seed(s) seed = int(hash((seed, *addl_seeds)) % 1e8) state = get_rng_state() torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) try: yield finally: set_rng_state(state) class CudaEnvironment(object): def __init__(self): cur_device = torch.cuda.current_device() prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device)) self.name = prop.name self.major = prop.major self.minor = prop.minor self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024 @staticmethod def pretty_print_cuda_env_list(cuda_env_list): """ Given a list of CudaEnviorments, pretty print them """ num_workers = len(cuda_env_list) center = "CUDA enviroments for all {} workers".format(num_workers) banner_len = 40 - len(center) // 2 first_line = "*" * banner_len + center + "*" * banner_len logger.info(first_line) for r, env in enumerate(cuda_env_list): logger.info( "rank {:3d}: ".format(r) + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor) + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB) + "name = {:40s}".format(env.name) ) logger.info(first_line) def csv_str_list(x): return x.split(",") def eval_str_list(x, type=float): if x is None: return None if isinstance(x, str): x = eval(x) try: return list(map(type, x)) except TypeError: return [type(x)] def eval_str_dict(x, type=dict): if x is None: return None if isinstance(x, str): x = eval(x) return x def eval_bool(x, default=False): if x is None: return default try: return bool(eval(x)) except TypeError: return default def permute_final_dims(tensor: torch.Tensor, inds: List[int]): zero_index = -1 * len(inds) first_inds = list(range(len(tensor.shape[:zero_index]))) return tensor.permute(first_inds + [zero_index + i for i in inds]) def flatten_final_dims(t: torch.Tensor, num_dims: int): return t.reshape(t.shape[:-num_dims] + (-1,)) def masked_mean(mask, value, dim, eps=1e-10): mask = mask.expand(*value.shape) return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) def dict_multimap(fn, dicts): first = dicts[0] new_dict = {} for k, v in first.items(): all_v = [d[k] for d in dicts] if type(v) is dict: new_dict[k] = dict_multimap(fn, all_v) else: new_dict[k] = fn(all_v) return new_dict def one_hot(x, num_classes, dtype=torch.float32): x_one_hot = torch.zeros(*x.shape, num_classes, dtype=dtype, device=x.device) x_one_hot.scatter_(-1, x.long().unsqueeze(-1), 1) return x_one_hot def batched_gather(data, inds, dim=0, num_batch_dims=0): assert dim < 0 or dim - num_batch_dims >= 0 ranges = [] for i, s in enumerate(data.shape[:num_batch_dims]): r = torch.arange(s) r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) ranges.append(r) remaining_dims = [ slice(None) for _ in range(len(data.shape) - num_batch_dims) ] remaining_dims[dim - num_batch_dims if dim >= 0 else dim] = inds ranges.extend(remaining_dims) return data[ranges] def dict_map(fn, dic, leaf_type): new_dict = {} for k, v in dic.items(): if type(v) is dict: new_dict[k] = dict_map(fn, v, leaf_type) else: new_dict[k] = tree_map(fn, v, leaf_type) return new_dict def tree_map(fn, tree, leaf_type): if isinstance(tree, dict): return dict_map(fn, tree, leaf_type) elif isinstance(tree, list): return [tree_map(fn, x, leaf_type) for x in tree] elif isinstance(tree, tuple): return tuple([tree_map(fn, x, leaf_type) for x in tree]) elif isinstance(tree, leaf_type): try: return fn(tree) except: raise ValueError(f"cannot apply {fn} on {tree}.") else: raise ValueError(f"{type(tree)} not supported") tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) def fp32_to_bf16_sr(t, o): if HAS_FUSED_ROUNDING and t.device.type == "cuda": unicore_fused_rounding.fp32_to_bf16_sr(t, o) else: r = (torch.rand(size=t.size(), device=t.device, dtype=torch.float32) - 0.5) / 256 m, e = torch.frexp(t) t = t + torch.ldexp(r, e) o.data.copy_(t.bfloat16()) def set_jit_fusion_options(): """Set PyTorch JIT layer fusion options.""" # flags required to enable jit fusion kernels TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): # nvfuser torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(False) torch._C._jit_set_texpr_fuser_enabled(False) torch._C._debug_set_autodiff_subgraph_inlining(False) else: # legacy pytorch fuser torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_gpu(True)