Unverified commit 79a78cae authored by Kirthi Shankar Sivamani, committed by GitHub

Get default dtype from pytorch (#300)



* Get default dtype from pytorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6bccc76e
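
In short, every affected module now defaults `params_dtype` to `None` and resolves `None` to `torch.get_default_dtype()` inside `__init__`, so parameter allocation follows PyTorch's global default dtype unless an explicit dtype is passed. A minimal standalone sketch of that resolution pattern (the `resolve_params_dtype` helper below is illustrative only and is not part of Transformer Engine):

```python
from typing import Optional

import torch


def resolve_params_dtype(params_dtype: Optional[torch.dtype] = None) -> torch.dtype:
    # Mirrors the pattern this commit adds to each module's __init__:
    # fall back to PyTorch's global default dtype when no explicit dtype is given.
    return torch.get_default_dtype() if params_dtype is None else params_dtype


print(resolve_params_dtype())               # torch.float32 under the stock default
print(resolve_params_dtype(torch.float16))  # an explicit dtype always wins

# Changing the global default changes what newly constructed modules allocate.
# (Recent PyTorch releases accept bfloat16 here; very old ones may not.)
torch.set_default_dtype(torch.bfloat16)
print(resolve_params_dtype())               # torch.bfloat16
torch.set_default_dtype(torch.float32)      # restore the usual default
```
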
@@ -962,7 +962,7 @@ class MultiHeadAttention(torch.nn.Module):
 fuse_wgrad_accumulation: bool = False,
 get_rng_state_tracker: Optional[Callable] = None,
 sequence_parallel: bool = False,
-params_dtype: torch.dtype = torch.float32,
+params_dtype: Optional[torch.dtype] = None,
 return_layernorm_output: bool = False,
 input_layernorm: bool = False,
 attention_type: str = "self",
@@ -983,7 +983,7 @@ class MultiHeadAttention(torch.nn.Module):
 self.get_rng_state_tracker = get_rng_state_tracker
 self.tp_group = tp_group
 self.return_layernorm_output = return_layernorm_output
-self.params_dtype = params_dtype
+self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
 self.init_method = init_method
 self.attn_mask_type = attn_mask_type
@@ -1008,7 +1008,7 @@ class MultiHeadAttention(torch.nn.Module):
 "tp_size": tp_size,
 "get_rng_state_tracker": get_rng_state_tracker,
 "sequence_parallel": sequence_parallel,
-"params_dtype": params_dtype,
+"params_dtype": self.params_dtype,
 }
 qkv_parallel_mode = "column" if set_parallel_mode else None
......
@@ -4,7 +4,7 @@
 """LayerNorm API"""
 import os
-from typing import Union, Tuple, Any, Mapping
+from typing import Union, Tuple, Any, Mapping, Optional
 import torch
 from torch.nn.parameter import Parameter
@@ -78,7 +78,7 @@ class LayerNorm(torch.nn.Module):
 a value added to the denominator of layer normalization for numerical stability.
 sequence_parallel : bool, default = `False`
 if set to `True`, uses sequence parallelism.
-params_dtype : torch.dtype, default = `torch.float32`
+params_dtype : torch.dtype, default = `torch.get_default_dtype()`
 it controls the type used to allocate the initial parameters. Useful when
 the model is trained with lower precision and the original FP32 parameters
 would not fit in GPU memory.
@@ -96,10 +96,11 @@ class LayerNorm(torch.nn.Module):
 hidden_size: int,
 eps: float = 1e-5,
 sequence_parallel: bool = False,
-params_dtype: torch.dtype = torch.float32,
+params_dtype: Optional[torch.dtype] = None,
 zero_centered_gamma: bool = False,
 ) -> None:
 super().__init__()
+params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
 self.eps = eps
 self.zero_centered_gamma = zero_centered_gamma
 self.weight = Parameter(
......
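
The intended user-facing effect is that modules constructed under `torch.set_default_dtype` pick up that dtype for their parameters. A hedged usage sketch, assuming `transformer_engine.pytorch` is importable as `te`, a CUDA device is available (Transformer Engine allocates parameters on the GPU), and a PyTorch build that accepts `bfloat16` as the default dtype:

```python
import torch
import transformer_engine.pytorch as te

# Parameters now follow the global default dtype when params_dtype is omitted.
torch.set_default_dtype(torch.bfloat16)
ln = te.LayerNorm(1024)
print(ln.weight.dtype)  # expected: torch.bfloat16

torch.set_default_dtype(torch.float32)  # restore the global default
```
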
@@ -602,7 +602,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
 instead return the bias value during the forward pass together with the
 output of the linear transformation :math:`y = xA^T`. This is useful when
 the bias addition can be fused to subsequent operations.
-params_dtype : torch.dtype, default = `torch.float32`
+params_dtype : torch.dtype, default = `torch.get_default_dtype()`
 it controls the type used to allocate the initial parameters. Useful when
 the model is trained with lower precision and the original FP32 parameters
 would not fit in GPU memory.
@@ -621,7 +621,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
 init_method: Optional[Callable] = None,
 bias: bool = True,
 return_bias: bool = False,
-params_dtype: torch.dtype = torch.float32,
+params_dtype: Optional[torch.dtype] = None,
 parallel_mode: Optional[str] = None,
 return_layernorm_output: bool = False,
 skip_weight_param_allocation: bool = False,
@@ -632,6 +632,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
 ub_split_ag: bool = False,
 ) -> None:
 super().__init__()
+params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
 self.in_features = in_features
 self.out_features = out_features
 self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
......
@@ -915,7 +915,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
 instead return the bias value during the forward pass together with the
 output of the linear transformation :math:`y = xA^T`. This is useful when
 the bias addition can be fused to subsequent operations.
-params_dtype : torch.dtype, default = `torch.float32`
+params_dtype : torch.dtype, default = `torch.get_default_dtype()`
 it controls the type used to allocate the initial parameters. Useful when
 the model is trained with lower precision and the original FP32 parameters
 would not fit in GPU memory.
@@ -944,7 +944,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
 activation : str = "gelu",
 output_layer_init_method: Optional[Callable] = None,
 fuse_wgrad_accumulation: bool = False,
-params_dtype: torch.dtype = torch.float32,
+params_dtype: Optional[torch.dtype] = None,
 return_layernorm_output: bool = False,
 seq_length: Optional[int] = None,
 micro_batch_size: Optional[int] = None,
@@ -957,6 +957,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
 ) -> None:
 super().__init__()
+params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
 self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
 self.use_bias = bias
 self.activation = activation
......
@@ -491,7 +491,7 @@ class Linear(TransformerEngineBaseModule):
 instead return the bias value during the forward pass together with the
 output of the linear transformation :math:`y = xA^T`. This is useful when
 the bias addition can be fused to subsequent operations.
-params_dtype : torch.dtype, default = `torch.float32`
+params_dtype : torch.dtype, default = `torch.get_default_dtype()`
 it controls the type used to allocate the initial parameters. Useful when
 the model is trained with lower precision and the original FP32 parameters
 would not fit in GPU memory.
@@ -509,7 +509,7 @@ class Linear(TransformerEngineBaseModule):
 init_method: Optional[Callable] = None,
 bias: bool = True,
 return_bias: bool = False,
-params_dtype: torch.dtype = torch.float32,
+params_dtype: Optional[torch.dtype] = None,
 parallel_mode: Optional[str] = None,
 skip_weight_param_allocation: bool = False,
 parameters_split: Optional[Tuple[str, ...]] = None,
@@ -517,6 +517,8 @@ class Linear(TransformerEngineBaseModule):
 ub_split_ag: bool = False,
 ) -> None:
 super().__init__()
+params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
 self.in_features = in_features
 self.out_features = out_features
 self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
......
@@ -165,7 +165,7 @@ class TransformerLayer(torch.nn.Module):
 have an additional `main_grad` attribute (used instead of the
 regular `grad`) which is a pre-allocated buffer of the correct
 size to accumulate gradients in.
-params_dtype : torch.dtype, default = `torch.float32`
+params_dtype : torch.dtype, default = `torch.get_default_dtype()`
 it controls the type used to allocate the initial parameters. Useful when
 the model is trained with lower precision and the original FP32 parameters
 would not fit in GPU memory.
@@ -202,7 +202,7 @@ class TransformerLayer(torch.nn.Module):
 self_attn_mask_type: str = "causal",
 tp_group: Optional[dist_group_type] = None,
 tp_size: int = 1,
-params_dtype: torch.dtype = torch.float32,
+params_dtype: Optional[torch.dtype] = None,
 get_rng_state_tracker: Optional[Callable] = None,
 fuse_wgrad_accumulation: bool = False,
 apply_query_key_layer_scaling: bool = False, # pylint: disable=unused-argument
@@ -235,6 +235,7 @@ class TransformerLayer(torch.nn.Module):
 tex.userbuf_comm_available()
 ), "Userbuffer communication backend not available."
+params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
 ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
 ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
 ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
......
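
Passing `params_dtype` explicitly still overrides the global default, so callers that relied on the old `torch.float32` default can keep that behaviour by spelling it out. Another sketch under the same assumptions as above (the specific layer sizes are arbitrary):

```python
import torch
import transformer_engine.pytorch as te

torch.set_default_dtype(torch.bfloat16)

# An explicit params_dtype still overrides the global default,
# so the old float32 behaviour is one keyword argument away.
fp32_linear = te.Linear(1024, 1024, params_dtype=torch.float32)
print(fp32_linear.weight.dtype)               # expected: torch.float32

# Omitting params_dtype now follows torch.get_default_dtype().
layer = te.TransformerLayer(1024, 4096, 16)
print({p.dtype for p in layer.parameters()})  # expected: {torch.bfloat16}

torch.set_default_dtype(torch.float32)        # restore the global default
```
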