# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Enums for e2e transformer"""
import torch
import transformer_engine_extensions as tex

# Lookup table: torch.dtype -> tex.DType member.
# Used for passing dtypes into the CUDA extension.
# Has a one-to-one mapping with the enum in transformer_engine.h.
# NOTE(review): torch.int8 is the key for kByte here; confirm the intended
# signedness against the enum definition in transformer_engine.h.
TE_DType = {
    torch.int8: tex.DType.kByte,
    torch.int32: tex.DType.kInt32,
    torch.float32: tex.DType.kFloat32,
    torch.half: tex.DType.kFloat16,
    torch.bfloat16: tex.DType.kBFloat16,
}

# Accepted attention-mask type names.
AttnMaskTypes = ("causal", "padding")

# Accepted attention type names.
AttnTypes = ("self", "cross")

# Accepted transformer layer type names.
LayerTypes = ("encoder", "decoder")

# Accepted GEMM parallel-mode values; None is also listed as a valid entry.
GemmParallelModes = ("row", "column", None)

# Alias for the process-group class exposed by torch's c10d bindings.
# NOTE(review): this reaches into the private torch._C namespace; presumably
# it is the same class as torch.distributed.ProcessGroup -- confirm before
# switching to the public path.
dist_group_type = torch._C._distributed_c10d.ProcessGroup