Unverified Commit 2451c8d1 authored by Kirthi Shankar Sivamani, committed by GitHub

Add CPU initialization method (#368)



* CPU initialization
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix default value
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change API and add to RMSNorm
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2e0bfbd9
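As a usage note for this change, here is a minimal sketch of the new `device` argument, assuming the Transformer Engine PyTorch modules touched in this diff (the module choice and sizes are illustrative):

```python
import torch
import transformer_engine.pytorch as te

# Allocate the module's parameters on the CPU instead of the default "cuda".
layer = te.Linear(1024, 1024, device="cpu")

# Per the new docstrings, it is the user's responsibility to move the
# parameters to the GPU before running the forward pass.
layer.cuda()

x = torch.randn(8, 1024, device="cuda")
y = layer(x)
```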
@@ -1217,6 +1217,7 @@ class MultiHeadAttention(torch.nn.Module):
ub_split_ag: bool = False,
bias: bool = True,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
self.layer_number = layer_number
@@ -1261,6 +1262,7 @@ class MultiHeadAttention(torch.nn.Module):
"get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": sequence_parallel,
"params_dtype": self.params_dtype,
"device": device,
}
qkv_parallel_mode = "column" if set_parallel_mode else None
......
@@ -104,6 +104,10 @@ class LayerNorm(torch.nn.Module):
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
"""
def __init__(
@@ -113,6 +117,7 @@ class LayerNorm(torch.nn.Module):
sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
@@ -121,14 +126,14 @@ class LayerNorm(torch.nn.Module):
self.weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
device=device,
dtype=params_dtype,
)
)
self.bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
device=device,
dtype=params_dtype,
)
)
......
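The docstring hunk above restates the zero-centered-gamma form of LayerNorm. For reference, the same math in plain PyTorch looks as follows; this is only a sketch of the documented formula, not the kernel the module actually calls:

```python
import torch
import torch.nn.functional as F

hidden_size = 16
eps = 1e-5
x = torch.randn(4, hidden_size)

# With zero_centered_gamma=True the learnable gamma starts at zero,
# so the effective scale (1 + gamma) starts at one.
gamma = torch.zeros(hidden_size)
beta = torch.zeros(hidden_size)

normalized = F.layer_norm(x, (hidden_size,), weight=None, bias=None, eps=eps)
y = normalized * (1.0 + gamma) + beta
```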
@@ -548,6 +548,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters
----------------------
@@ -608,6 +612,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
@@ -666,20 +671,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.eps = eps
self.layer_norm_weight = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
torch.empty(in_features, device=device, dtype=params_dtype)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
torch.empty(in_features, device=device, dtype=params_dtype)
)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else:
@@ -688,8 +685,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.weight_tensor = torch.empty(
self.out_features, self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype)
device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.weight_tensor,
@@ -702,11 +698,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.use_bias:
self.bias_tensor = torch.empty(
self.out_features,
device=torch.cuda.current_device(),
device=device,
dtype=params_dtype)
else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device())
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
with torch.no_grad():
self.bias_tensor.zero_()
@@ -743,8 +738,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()))
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
......
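A small detail in the hunks above: when `use_bias` is False, the module keeps a zero-element placeholder tensor, which is now created on the user-selected device as well. A standalone sketch of that pattern (the values here are illustrative):

```python
import torch

params_dtype = torch.float32
device = "cpu"   # illustrative; the module's default remains "cuda"
use_bias = False

if use_bias:
    bias_tensor = torch.empty(1024, device=device, dtype=params_dtype)
else:
    # Zero-element placeholder that still carries the requested dtype and device.
    bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)

print(bias_tensor.numel(), bias_tensor.dtype, bias_tensor.device)  # 0 torch.float32 cpu
```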
@@ -874,6 +874,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters
----------------------
@@ -944,6 +948,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
@@ -990,20 +995,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
# LN init
self.eps = eps
self.layer_norm_weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
torch.empty(hidden_size, device=device, dtype=params_dtype)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
torch.empty(hidden_size, device=device, dtype=params_dtype)
)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else:
@@ -1016,12 +1013,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_output_features = self.size_per_partition
# FC1 init
self.fc1_weight = Parameter(
torch.empty(
fc1_output_features,
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype)
)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
@@ -1035,28 +1027,18 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.use_bias:
self.fc1_bias = Parameter(
torch.empty(
fc1_output_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
torch.empty(fc1_output_features, device=device, dtype=params_dtype)
)
set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
else:
self.fc1_bias = torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device())
self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)
with torch.no_grad():
self.fc1_bias.zero_()
# FC2 init
self.fc2_weight = Parameter(
torch.empty(
hidden_size,
self.size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
@@ -1070,13 +1052,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.use_bias:
self.fc2_bias = Parameter(
torch.empty(
hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
)
torch.empty(hidden_size, device=device, dtype=params_dtype)
)
else:
self.fc2_bias = torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device())
self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......
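Since LayerNormMLP holds the largest weights in a transformer block, it is where CPU initialization matters most. A hedged sketch of how one might verify that construction no longer allocates those weights on the GPU (constructor arguments as shown in this diff; the sizes are illustrative):

```python
import torch
import transformer_engine.pytorch as te

torch.cuda.reset_peak_memory_stats()

# Parameters are created on the host, so the large FC1/FC2 weights are not
# placed in GPU memory at construction time.
mlp = te.LayerNormMLP(hidden_size=4096, ffn_hidden_size=16384, device="cpu")
print("allocated after CPU construction:", torch.cuda.memory_allocated())

# Move to the GPU before the forward pass, as the docstring requires.
mlp.cuda()
print("allocated after .cuda():", torch.cuda.memory_allocated())
```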
@@ -466,6 +466,10 @@ class Linear(TransformerEngineBaseModule):
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters
----------------------
@@ -521,6 +525,7 @@ class Linear(TransformerEngineBaseModule):
parameters_split: Optional[Tuple[str, ...]] = None,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
@@ -574,8 +579,7 @@ class Linear(TransformerEngineBaseModule):
self.weight_tensor = torch.empty(
self.out_features, self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype)
device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.weight_tensor,
@@ -586,13 +590,9 @@ class Linear(TransformerEngineBaseModule):
)
if self.use_bias:
self.bias_tensor = torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype)
self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device())
self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
with torch.no_grad():
self.bias_tensor.zero_()
@@ -629,8 +629,7 @@ class Linear(TransformerEngineBaseModule):
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()))
setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
......
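The Linear hunks above also touch the `parameters_split` path, where one bias buffer is exposed as several named Parameters sliced from it. A self-contained sketch of that slicing pattern (the names and sizes here are hypothetical, not the module's real attributes):

```python
import torch
from torch.nn.parameter import Parameter

out_features = 12
names = ("query_", "key_", "value_")   # hypothetical split names
split_size = out_features // len(names)

bias_tensor = torch.zeros(out_features)
split_biases = {
    name + "bias": Parameter(bias_tensor[i * split_size : (i + 1) * split_size])
    for i, name in enumerate(names)
}

for pname, param in split_biases.items():
    print(pname, tuple(param.shape))   # each split parameter has shape (4,)
```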
@@ -114,6 +114,10 @@ class RMSNorm(torch.nn.Module):
.. math::
y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma)
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
"""
def __init__(
@@ -123,6 +127,7 @@ class RMSNorm(torch.nn.Module):
sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
@@ -131,7 +136,7 @@ class RMSNorm(torch.nn.Module):
self.weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
device=device,
dtype=params_dtype,
)
)
......
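To make the RMSNorm docstring formula concrete, here is a plain-PyTorch restatement of the zero-centered-gamma expression; a reference sketch only, not the fused kernel the module runs:

```python
import torch

def rmsnorm_reference(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMS(x) over the hidden dimension, per the docstring.
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True))
    # Zero-centered gamma: the effective scale is (1 + gamma).
    return x / (rms + eps) * (1.0 + gamma)

hidden_size = 8
x = torch.randn(2, hidden_size)
gamma = torch.zeros(hidden_size)   # zero_centered_gamma initialization
y = rmsnorm_reference(x, gamma)
```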
@@ -149,6 +149,10 @@ class TransformerLayer(torch.nn.Module):
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters
----------------------
@@ -233,6 +237,7 @@ class TransformerLayer(torch.nn.Module):
bias: bool = True,
activation: str = 'gelu',
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
) -> None:
super().__init__()
@@ -326,6 +331,7 @@ class TransformerLayer(torch.nn.Module):
attention_type="self",
bias=bias,
normalization=normalization,
device=device,
)
if layer_type == "decoder":
@@ -337,6 +343,7 @@ class TransformerLayer(torch.nn.Module):
attention_type="cross",
bias=bias,
normalization=normalization,
device=device,
)
# LayerNorm -> activation(Linear + Bias) -> Linear
@@ -369,6 +376,7 @@ class TransformerLayer(torch.nn.Module):
ub_split_ag=ub_split_ag,
activation=activation,
normalization=normalization,
device=device,
)
self.hidden_dropout = hidden_dropout
@@ -402,7 +410,8 @@ class TransformerLayer(torch.nn.Module):
eps=layernorm_epsilon,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
zero_centered_gamma=zero_centered_gamma
zero_centered_gamma=zero_centered_gamma,
device=device,
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
......