Unverified Commit 062a1d0f authored by QQSong's avatar QQSong Committed by GitHub
Browse files

Fix ReplicatedLinear weight loading (#6793)

parent 2eb9f4ff
...@@ -199,12 +199,16 @@ class ReplicatedLinear(LinearBase): ...@@ -199,12 +199,16 @@ class ReplicatedLinear(LinearBase):
self.input_size, self.input_size,
self.output_size, self.output_size,
self.params_dtype, self.params_dtype,
weight_loader=self.weight_loader,
prefix=prefix) prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype)) torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {"output_dim": 0}) set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment