"...include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "20e2de7d8bba9f3735c9a67b000f04853927a0f8"
Unverified commit 79a78cae authored by Kirthi Shankar Sivamani, committed by GitHub

Get default dtype from pytorch (#300)



* Get default dtype from pytorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6bccc76e
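The change replaces the hard-coded `torch.float32` default for `params_dtype` with a `None` sentinel that each module resolves against `torch.get_default_dtype()` in its constructor. A minimal standalone sketch of the resolution pattern (`resolve_params_dtype` is an illustrative helper, not part of the Transformer Engine API):

```python
import torch

def resolve_params_dtype(params_dtype=None):
    # Fall back to the framework-wide default dtype when no explicit
    # params_dtype is given; an explicit value always wins.
    return torch.get_default_dtype() if params_dtype is None else params_dtype

torch.set_default_dtype(torch.float32)
assert resolve_params_dtype() == torch.float32                # old behavior preserved

torch.set_default_dtype(torch.bfloat16)
assert resolve_params_dtype() == torch.bfloat16               # now tracks the global default
assert resolve_params_dtype(torch.float16) == torch.float16   # explicit override unchanged

torch.set_default_dtype(torch.float32)  # restore for anything that runs after
```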
@@ -962,7 +962,7 @@ class MultiHeadAttention(torch.nn.Module):
         fuse_wgrad_accumulation: bool = False,
         get_rng_state_tracker: Optional[Callable] = None,
         sequence_parallel: bool = False,
-        params_dtype: torch.dtype = torch.float32,
+        params_dtype: Optional[torch.dtype] = None,
         return_layernorm_output: bool = False,
         input_layernorm: bool = False,
         attention_type: str = "self",
@@ -983,7 +983,7 @@ class MultiHeadAttention(torch.nn.Module):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.tp_group = tp_group
         self.return_layernorm_output = return_layernorm_output
-        self.params_dtype = params_dtype
+        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.init_method = init_method
         self.attn_mask_type = attn_mask_type
@@ -1008,7 +1008,7 @@ class MultiHeadAttention(torch.nn.Module):
             "tp_size": tp_size,
             "get_rng_state_tracker": get_rng_state_tracker,
             "sequence_parallel": sequence_parallel,
-            "params_dtype": params_dtype,
+            "params_dtype": self.params_dtype,
         }
         qkv_parallel_mode = "column" if set_parallel_mode else None
...
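Note the last hunk above: the keyword dict shared with the sub-layers now carries the resolved `self.params_dtype` rather than the raw argument, so nested modules receive a concrete dtype instead of re-resolving `None` themselves. A standalone sketch of that design choice (class names are illustrative, not TE code):

```python
import torch

class Child:
    # Stand-in for a nested module that applies the same None -> default resolution.
    def __init__(self, params_dtype=None):
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype

class Parent:
    def __init__(self, params_dtype=None):
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        # Pass the resolved dtype down: the child is pinned to the parent's
        # choice instead of consulting the global default a second time.
        self.child = Child(params_dtype=self.params_dtype)

p = Parent(params_dtype=torch.float16)
assert p.child.params_dtype == torch.float16
```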
@@ -4,7 +4,7 @@
 """LayerNorm API"""
 import os
-from typing import Union, Tuple, Any, Mapping
+from typing import Union, Tuple, Any, Mapping, Optional
 import torch
 from torch.nn.parameter import Parameter
@@ -78,7 +78,7 @@ class LayerNorm(torch.nn.Module):
         a value added to the denominator of layer normalization for numerical stability.
     sequence_parallel : bool, default = `False`
         if set to `True`, uses sequence parallelism.
-    params_dtype : torch.dtype, default = `torch.float32`
+    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
         it controls the type used to allocate the initial parameters. Useful when
         the model is trained with lower precision and the original FP32 parameters
         would not fit in GPU memory.
@@ -96,10 +96,11 @@ class LayerNorm(torch.nn.Module):
         hidden_size: int,
         eps: float = 1e-5,
         sequence_parallel: bool = False,
-        params_dtype: torch.dtype = torch.float32,
+        params_dtype: Optional[torch.dtype] = None,
         zero_centered_gamma: bool = False,
     ) -> None:
         super().__init__()
+        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.eps = eps
         self.zero_centered_gamma = zero_centered_gamma
         self.weight = Parameter(
...
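With the sentinel in place, constructing a module without an explicit `params_dtype` allocates its parameters in whatever dtype PyTorch is currently defaulting to. A hedged usage sketch, assuming the usual `transformer_engine.pytorch` import alias and a CUDA device (TE allocates parameters on the GPU):

```python
import torch
import transformer_engine.pytorch as te

torch.set_default_dtype(torch.bfloat16)
ln = te.LayerNorm(1024)        # params_dtype omitted -> resolves to torch.bfloat16
print(ln.weight.dtype)         # expected after this change: torch.bfloat16

torch.set_default_dtype(torch.float32)
ln32 = te.LayerNorm(1024)
print(ln32.weight.dtype)       # torch.float32, matching the old hard-coded default
```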
@@ -602,7 +602,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         instead return the bias value during the forward pass together with the
         output of the linear transformation :math:`y = xA^T`. This is useful when
         the bias addition can be fused to subsequent operations.
-    params_dtype : torch.dtype, default = `torch.float32`
+    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
         it controls the type used to allocate the initial parameters. Useful when
         the model is trained with lower precision and the original FP32 parameters
         would not fit in GPU memory.
@@ -621,7 +621,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         init_method: Optional[Callable] = None,
         bias: bool = True,
         return_bias: bool = False,
-        params_dtype: torch.dtype = torch.float32,
+        params_dtype: Optional[torch.dtype] = None,
         parallel_mode: Optional[str] = None,
         return_layernorm_output: bool = False,
         skip_weight_param_allocation: bool = False,
@@ -632,6 +632,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ub_split_ag: bool = False,
     ) -> None:
         super().__init__()
+        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
...
@@ -915,7 +915,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         instead return the bias value during the forward pass together with the
         output of the linear transformation :math:`y = xA^T`. This is useful when
         the bias addition can be fused to subsequent operations.
-    params_dtype : torch.dtype, default = `torch.float32`
+    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
         it controls the type used to allocate the initial parameters. Useful when
         the model is trained with lower precision and the original FP32 parameters
         would not fit in GPU memory.
@@ -944,7 +944,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         activation: str = "gelu",
         output_layer_init_method: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
-        params_dtype: torch.dtype = torch.float32,
+        params_dtype: Optional[torch.dtype] = None,
         return_layernorm_output: bool = False,
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
@@ -957,6 +957,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
     ) -> None:
         super().__init__()
+        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         self.use_bias = bias
         self.activation = activation
...
@@ -491,7 +491,7 @@ class Linear(TransformerEngineBaseModule):
         instead return the bias value during the forward pass together with the
         output of the linear transformation :math:`y = xA^T`. This is useful when
         the bias addition can be fused to subsequent operations.
-    params_dtype : torch.dtype, default = `torch.float32`
+    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
         it controls the type used to allocate the initial parameters. Useful when
         the model is trained with lower precision and the original FP32 parameters
         would not fit in GPU memory.
@@ -509,7 +509,7 @@ class Linear(TransformerEngineBaseModule):
         init_method: Optional[Callable] = None,
         bias: bool = True,
         return_bias: bool = False,
-        params_dtype: torch.dtype = torch.float32,
+        params_dtype: Optional[torch.dtype] = None,
         parallel_mode: Optional[str] = None,
         skip_weight_param_allocation: bool = False,
         parameters_split: Optional[Tuple[str, ...]] = None,
@@ -517,6 +517,8 @@ class Linear(TransformerEngineBaseModule):
         ub_split_ag: bool = False,
     ) -> None:
         super().__init__()
+        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
...
@@ -165,7 +165,7 @@ class TransformerLayer(torch.nn.Module):
         have an additional `main_grad` attribute (used instead of the
         regular `grad`) which is a pre-allocated buffer of the correct
         size to accumulate gradients in.
-    params_dtype : torch.dtype, default = `torch.float32`
+    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
         it controls the type used to allocate the initial parameters. Useful when
         the model is trained with lower precision and the original FP32 parameters
         would not fit in GPU memory.
@@ -202,7 +202,7 @@ class TransformerLayer(torch.nn.Module):
         self_attn_mask_type: str = "causal",
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
-        params_dtype: torch.dtype = torch.float32,
+        params_dtype: Optional[torch.dtype] = None,
         get_rng_state_tracker: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
         apply_query_key_layer_scaling: bool = False,  # pylint: disable=unused-argument
@@ -235,6 +235,7 @@ class TransformerLayer(torch.nn.Module):
                 tex.userbuf_comm_available()
             ), "Userbuffer communication backend not available."
+        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
         ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
         ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
...
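Resolving the sentinel inside `__init__` rather than in the signature is deliberate: Python evaluates default argument values once, at definition time, so `params_dtype: torch.dtype = torch.get_default_dtype()` would freeze whatever the global default happened to be at import. A small standalone sketch of the difference (assumes the process starts with the stock `torch.float32` default):

```python
import torch

def frozen_default(dtype=torch.get_default_dtype()):
    # Default captured once, when this function is defined.
    return dtype

def sentinel_default(dtype=None):
    # Default looked up on every call.
    return torch.get_default_dtype() if dtype is None else dtype

torch.set_default_dtype(torch.bfloat16)
print(frozen_default())    # torch.float32 -- stale, captured at definition time
print(sentinel_default())  # torch.bfloat16 -- follows the current global default
torch.set_default_dtype(torch.float32)
```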