Commit be8de1b3 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

Fixed propagation of the `shared` weight attribute to fp32 parameter shards.

parent 2c9ed910
...@@ -237,6 +237,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -237,6 +237,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
shard_fp32_params_this_group.append(shard_model_param) shard_fp32_params_this_group.append(shard_model_param)
mpu.copy_tensor_model_parallel_attributes( mpu.copy_tensor_model_parallel_attributes(
shard_model_param, model_param) shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
else: else:
raise TypeError('Wrapped parameters must be one of ' raise TypeError('Wrapped parameters must be one of '
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment