Commit bdd47d64 authored by Jared Casper

Address comments, fix argument bug.

parent 78066ab0
@@ -503,7 +503,7 @@ def _add_distributed_args(parser):
                        ' and returns function to complete it instead.'
                        'Also turns on --use-cpu-initialization flag.'
                        'This is for external DDP manager.' )
-    group.add_argument('--use-cpu-initialization', type=bool, required=False,
+    group.add_argument('--use-cpu-initialization', action='store_true', default=None,
                        help='If set, affine parallel weights initialization uses CPU' )
     return parser
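For context on the --use-cpu-initialization change: with argparse, type=bool does not parse booleans; it applies bool() to the raw string, so any non-empty value, including "False", comes out True. The store_true action makes the option a genuine flag, and default=None lets downstream code distinguish "not passed" from "explicitly enabled". A minimal standalone sketch, separate from this commit and using hypothetical flag names, that shows the difference:

import argparse

parser = argparse.ArgumentParser()
# Old pattern: argparse applies bool() to the raw command-line string, and
# bool() of any non-empty string -- including "False" and "0" -- is True.
parser.add_argument('--buggy-flag', type=bool, required=False)
# New pattern: a real on/off switch; default=None lets later code tell
# "not given on the command line" apart from "explicitly enabled".
parser.add_argument('--fixed-flag', action='store_true', default=None)

args = parser.parse_args(['--buggy-flag', 'False'])
print(args.buggy_flag)  # True: the string "False" is truthy
print(args.fixed_flag)  # None: the flag was not passed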
@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module):
                 self.output_size_per_partition,
                 device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
-            self.bias.tensor_model_parallel = True
-            self.bias.partition_dim = 0
-            self.bias.partition_stride = stride
+            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
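This hunk folds the three manual attribute assignments on the bias into a single call to set_tensor_model_parallel_attributes. A minimal sketch of what that call is assumed to do, inferred only from the removed lines (the helper's real implementation may differ):

import torch

def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Assumed behavior, mirroring the three assignments the hunk removes:
    # tag the parameter as tensor-model-parallel and record how it is partitioned.
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride

# Equivalent to the removed lines: the column-parallel bias is split along
# dimension 0 with the given stride (stride value here is illustrative).
bias = torch.nn.Parameter(torch.empty(16))
set_tensor_model_parallel_attributes(bias, True, 0, 1)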
@@ -199,15 +199,16 @@ def main():
                                    'no_load_rng': True,
                                    'save_interval': 1})
     args = get_args()
-    model_type = args.model_type
-    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
-    args.tensor_model_parallel_size = 1
-    tokenizer = rebuild_tokenizer(args)
 
     if args.pipeline_model_parallel_size > 1:
         print("Checkpoints with pipeline model parallelism are not currently supported.")
         exit()
 
+    model_type = args.model_type
+    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
+    args.tensor_model_parallel_size = 1
+    tokenizer = rebuild_tokenizer(args)
+
     print('\n merging model parallel partitions ...')
     print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
     print(' > checkpoint path: {}'.format(args.load))
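The effect of this move appears to be ordering only: the unsupported-pipeline-parallel check now runs and exits before args.tensor_model_parallel_size is forced to 1 and the tokenizer is rebuilt, while orig_tensor_model_parallel_size still records the checkpoint's original partition count used by the prints that follow.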