Commit e10760ea authored by Boris Fomitchev

Making --lazy_mpu_init act as _USE_CPU_INITIALIZATION

parent d4bb6cb0
...@@ -141,7 +141,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -141,7 +141,7 @@ class VocabParallelEmbedding(torch.nn.Module):
# Allocate weights and initialize. # Allocate weights and initialize.
args = get_args() args = get_args()
if _USE_CPU_INITIALIZATION: if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
dtype=args.params_dtype)) dtype=args.params_dtype))
...@@ -217,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -217,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module):
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. # Initialize weight.
args = get_args() args = get_args()
if _USE_CPU_INITIALIZATION: if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
self.weight = Parameter(torch.empty(self.output_size_per_partition, self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size, self.input_size,
dtype=args.params_dtype)) dtype=args.params_dtype))
...@@ -233,7 +233,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -233,7 +233,7 @@ class ColumnParallelLinear(torch.nn.Module):
partition_dim=0, stride=stride) partition_dim=0, stride=stride)
if bias: if bias:
if _USE_CPU_INITIALIZATION: if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size_per_partition, dtype=args.params_dtype)) self.output_size_per_partition, dtype=args.params_dtype))
else: else:
...@@ -311,7 +311,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -311,7 +311,7 @@ class RowParallelLinear(torch.nn.Module):
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. # Initialize weight.
args = get_args() args = get_args()
if _USE_CPU_INITIALIZATION: if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
self.weight = Parameter(torch.empty(self.output_size, self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition, self.input_size_per_partition,
dtype=args.params_dtype)) dtype=args.params_dtype))
...@@ -326,7 +326,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -326,7 +326,7 @@ class RowParallelLinear(torch.nn.Module):
_initialize_affine_weight_gpu(self.weight, init_method, _initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride) partition_dim=1, stride=stride)
if bias: if bias:
if _USE_CPU_INITIALIZATION: if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
self.bias = Parameter(torch.empty(self.output_size, self.bias = Parameter(torch.empty(self.output_size,
dtype=args.params_dtype)) dtype=args.params_dtype))
else: else:
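For context, here is a minimal, hypothetical sketch of the allocation pattern each hunk now guards: when either flag is set, the parameter is created uninitialized on the CPU instead of on the current CUDA device. Identifiers such as _USE_CPU_INITIALIZATION, args.params_dtype, and _initialize_affine_weight_gpu come from the diff above; the standalone allocate_weight helper and its signature are illustrative only and not part of the commit.

import torch
from torch.nn.parameter import Parameter

def allocate_weight(out_features, in_features, params_dtype, lazy_or_cpu_init):
    # lazy_or_cpu_init stands in for the guard this commit changes to
    # `if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:`.
    if lazy_or_cpu_init:
        # CPU path: create an uninitialized parameter on the host so no GPU
        # memory is touched yet; initialization can be deferred by the caller.
        return Parameter(torch.empty(out_features, in_features,
                                     dtype=params_dtype))
    # GPU path: allocate on the current CUDA device; in the code above the
    # tensor is then filled by _initialize_affine_weight_gpu(...).
    return Parameter(torch.empty(out_features, in_features,
                                 device=torch.cuda.current_device(),
                                 dtype=params_dtype))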