Unverified Commit 90979c38 authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[FIX] Don't initialize parameter by default (#1067)

parent e21d7687
...@@ -83,7 +83,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -83,7 +83,7 @@ class VocabParallelEmbedding(torch.nn.Module):
init_method=init.xavier_normal_, init_method=init.xavier_normal_,
params_dtype: torch.dtype=None, params_dtype: torch.dtype=None,
use_cpu_initialization: bool=False, use_cpu_initialization: bool=False,
perform_initialization: bool=True): perform_initialization: bool=False):
super(VocabParallelEmbedding, self).__init__() super(VocabParallelEmbedding, self).__init__()
assert not perform_initialization assert not perform_initialization
assert not use_cpu_initialization assert not use_cpu_initialization
...@@ -113,7 +113,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -113,7 +113,7 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter(torch.empty( self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype)) device=torch.cuda.current_device(), dtype=params_dtype))
def forward(self, input_): def forward(self, input_):
if self.tensor_model_parallel_size > 1: if self.tensor_model_parallel_size > 1:
# Build the mask. # Build the mask.
...@@ -172,7 +172,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -172,7 +172,7 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add=False, skip_bias_add=False,
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=False,
quant_config=None, quant_config=None,
): ):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
...@@ -288,7 +288,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -288,7 +288,7 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add=False, skip_bias_add=False,
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=False,
reduce_results=True, reduce_results=True,
quant_config=None, quant_config=None,
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment