"vscode:/vscode.git/clone" did not exist on "e57c0c301c7fbe98d9c96a86890443cd83fc5bb9"
Unverified Commit 2451c8d1 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Add CPU initialization method (#368)



* CPU initialization
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix default value
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change API and add to RMSNorm
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2e0bfbd9
...@@ -1217,6 +1217,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1217,6 +1217,7 @@ class MultiHeadAttention(torch.nn.Module):
ub_split_ag: bool = False, ub_split_ag: bool = False,
bias: bool = True, bias: bool = True,
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_number = layer_number self.layer_number = layer_number
...@@ -1261,6 +1262,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1261,6 +1262,7 @@ class MultiHeadAttention(torch.nn.Module):
"get_rng_state_tracker": get_rng_state_tracker, "get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": sequence_parallel, "sequence_parallel": sequence_parallel,
"params_dtype": self.params_dtype, "params_dtype": self.params_dtype,
"device": device,
} }
qkv_parallel_mode = "column" if set_parallel_mode else None qkv_parallel_mode = "column" if set_parallel_mode else None
......
...@@ -104,6 +104,10 @@ class LayerNorm(torch.nn.Module): ...@@ -104,6 +104,10 @@ class LayerNorm(torch.nn.Module):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta (1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
""" """
def __init__( def __init__(
...@@ -113,6 +117,7 @@ class LayerNorm(torch.nn.Module): ...@@ -113,6 +117,7 @@ class LayerNorm(torch.nn.Module):
sequence_parallel: bool = False, sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None: ) -> None:
super().__init__() super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
...@@ -121,14 +126,14 @@ class LayerNorm(torch.nn.Module): ...@@ -121,14 +126,14 @@ class LayerNorm(torch.nn.Module):
self.weight = Parameter( self.weight = Parameter(
torch.empty( torch.empty(
hidden_size, hidden_size,
device=torch.cuda.current_device(), device=device,
dtype=params_dtype, dtype=params_dtype,
) )
) )
self.bias = Parameter( self.bias = Parameter(
torch.empty( torch.empty(
hidden_size, hidden_size,
device=torch.cuda.current_device(), device=device,
dtype=params_dtype, dtype=params_dtype,
) )
) )
......
...@@ -548,6 +548,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -548,6 +548,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta (1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -608,6 +612,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -608,6 +612,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -666,20 +671,12 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -666,20 +671,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.eps = eps self.eps = eps
self.layer_norm_weight = Parameter( self.layer_norm_weight = Parameter(
torch.empty( torch.empty(in_features, device=device, dtype=params_dtype)
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
) )
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm": if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter( self.layer_norm_bias = Parameter(
torch.empty( torch.empty(in_features, device=device, dtype=params_dtype)
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
) )
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else: else:
...@@ -688,8 +685,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -688,8 +685,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.weight_tensor = torch.empty( self.weight_tensor = torch.empty(
self.out_features, self.in_features, self.out_features, self.in_features,
device=torch.cuda.current_device(), device=device, dtype=params_dtype)
dtype=params_dtype)
initialize_affine_weight_gpu( initialize_affine_weight_gpu(
self.weight_tensor, self.weight_tensor,
...@@ -702,11 +698,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -702,11 +698,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.use_bias: if self.use_bias:
self.bias_tensor = torch.empty( self.bias_tensor = torch.empty(
self.out_features, self.out_features,
device=torch.cuda.current_device(), device=device,
dtype=params_dtype) dtype=params_dtype)
else: else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
device=torch.cuda.current_device())
with torch.no_grad(): with torch.no_grad():
self.bias_tensor.zero_() self.bias_tensor.zero_()
...@@ -743,8 +738,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -743,8 +738,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
) )
else: else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
device=torch.cuda.current_device()))
if parallel_mode == "column": if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
......
...@@ -874,6 +874,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -874,6 +874,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta (1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -944,6 +948,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -944,6 +948,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False, ub_split_rs: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -990,20 +995,12 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -990,20 +995,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
# LN init # LN init
self.eps = eps self.eps = eps
self.layer_norm_weight = Parameter( self.layer_norm_weight = Parameter(
torch.empty( torch.empty(hidden_size, device=device, dtype=params_dtype)
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
) )
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm": if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter( self.layer_norm_bias = Parameter(
torch.empty( torch.empty(hidden_size, device=device, dtype=params_dtype)
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
) )
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else: else:
...@@ -1016,12 +1013,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1016,12 +1013,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_output_features = self.size_per_partition fc1_output_features = self.size_per_partition
# FC1 init # FC1 init
self.fc1_weight = Parameter( self.fc1_weight = Parameter(
torch.empty( torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype)
fc1_output_features,
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
) )
self.fp8_weight_shapes.append(self.fc1_weight.shape) self.fp8_weight_shapes.append(self.fc1_weight.shape)
...@@ -1035,28 +1027,18 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1035,28 +1027,18 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.use_bias: if self.use_bias:
self.fc1_bias = Parameter( self.fc1_bias = Parameter(
torch.empty( torch.empty(fc1_output_features, device=device, dtype=params_dtype)
fc1_output_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
) )
set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
else: else:
self.fc1_bias = torch.Tensor().to(dtype=params_dtype, self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)
device=torch.cuda.current_device())
with torch.no_grad(): with torch.no_grad():
self.fc1_bias.zero_() self.fc1_bias.zero_()
# FC2 init # FC2 init
self.fc2_weight = Parameter( self.fc2_weight = Parameter(
torch.empty( torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
hidden_size,
self.size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
) )
self.fp8_weight_shapes.append(self.fc2_weight.shape) self.fp8_weight_shapes.append(self.fc2_weight.shape)
...@@ -1070,13 +1052,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1070,13 +1052,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.use_bias: if self.use_bias:
self.fc2_bias = Parameter( self.fc2_bias = Parameter(
torch.empty( torch.empty(hidden_size, device=device, dtype=params_dtype)
hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
)
) )
else: else:
self.fc2_bias = torch.Tensor().to(dtype=params_dtype, self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)
device=torch.cuda.current_device())
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
......
...@@ -466,6 +466,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -466,6 +466,10 @@ class Linear(TransformerEngineBaseModule):
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters. and the strings contained are the names of the split parameters.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -521,6 +525,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -521,6 +525,7 @@ class Linear(TransformerEngineBaseModule):
parameters_split: Optional[Tuple[str, ...]] = None, parameters_split: Optional[Tuple[str, ...]] = None,
ub_split_rs: bool = False, ub_split_rs: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -574,8 +579,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -574,8 +579,7 @@ class Linear(TransformerEngineBaseModule):
self.weight_tensor = torch.empty( self.weight_tensor = torch.empty(
self.out_features, self.in_features, self.out_features, self.in_features,
device=torch.cuda.current_device(), device=device, dtype=params_dtype)
dtype=params_dtype)
initialize_affine_weight_gpu( initialize_affine_weight_gpu(
self.weight_tensor, self.weight_tensor,
...@@ -586,13 +590,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -586,13 +590,9 @@ class Linear(TransformerEngineBaseModule):
) )
if self.use_bias: if self.use_bias:
self.bias_tensor = torch.empty( self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype)
else: else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
device=torch.cuda.current_device())
with torch.no_grad(): with torch.no_grad():
self.bias_tensor.zero_() self.bias_tensor.zero_()
...@@ -629,8 +629,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -629,8 +629,7 @@ class Linear(TransformerEngineBaseModule):
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
) )
else: else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
device=torch.cuda.current_device()))
if parallel_mode == "column": if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
......
...@@ -114,6 +114,10 @@ class RMSNorm(torch.nn.Module): ...@@ -114,6 +114,10 @@ class RMSNorm(torch.nn.Module):
.. math:: .. math::
y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma) y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma)
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
""" """
def __init__( def __init__(
...@@ -123,6 +127,7 @@ class RMSNorm(torch.nn.Module): ...@@ -123,6 +127,7 @@ class RMSNorm(torch.nn.Module):
sequence_parallel: bool = False, sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None: ) -> None:
super().__init__() super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
...@@ -131,7 +136,7 @@ class RMSNorm(torch.nn.Module): ...@@ -131,7 +136,7 @@ class RMSNorm(torch.nn.Module):
self.weight = Parameter( self.weight = Parameter(
torch.empty( torch.empty(
hidden_size, hidden_size,
device=torch.cuda.current_device(), device=device,
dtype=params_dtype, dtype=params_dtype,
) )
) )
......
...@@ -149,6 +149,10 @@ class TransformerLayer(torch.nn.Module): ...@@ -149,6 +149,10 @@ class TransformerLayer(torch.nn.Module):
activation : str, default = 'gelu' activation : str, default = 'gelu'
Type of activation used in MLP block. Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'. Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -233,6 +237,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -233,6 +237,7 @@ class TransformerLayer(torch.nn.Module):
bias: bool = True, bias: bool = True,
activation: str = 'gelu', activation: str = 'gelu',
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -326,6 +331,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -326,6 +331,7 @@ class TransformerLayer(torch.nn.Module):
attention_type="self", attention_type="self",
bias=bias, bias=bias,
normalization=normalization, normalization=normalization,
device=device,
) )
if layer_type == "decoder": if layer_type == "decoder":
...@@ -337,6 +343,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -337,6 +343,7 @@ class TransformerLayer(torch.nn.Module):
attention_type="cross", attention_type="cross",
bias=bias, bias=bias,
normalization=normalization, normalization=normalization,
device=device,
) )
# LayerNorm -> activation(Linear + Bias) -> Linear # LayerNorm -> activation(Linear + Bias) -> Linear
...@@ -369,6 +376,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -369,6 +376,7 @@ class TransformerLayer(torch.nn.Module):
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
device=device,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
...@@ -402,7 +410,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -402,7 +410,8 @@ class TransformerLayer(torch.nn.Module):
eps=layernorm_epsilon, eps=layernorm_epsilon,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype, params_dtype=params_dtype,
zero_centered_gamma=zero_centered_gamma zero_centered_gamma=zero_centered_gamma,
device=device,
) )
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment