"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "3dbbf83f1c46ae2a3b2947e1a5925c2b8af9f7b1"
Unverified Commit 763dc325 authored by Ziyue Jiang, committed by GitHub
Browse files

[TP] Add gather_out arg to Linear (#541)

parent 8c90d4df
import math
import inspect
from typing import Callable
from colossalai.utils import get_current_device
......@@ -78,15 +79,19 @@ class Linear(nn.Module):
if self.layer.bias is not None:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
linear_cls = _parallel_linear[tensor_parallel]
gather_output = kwargs.pop('gather_output', None)
if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available
kwargs['gather_output'] = gather_output
self.layer = linear_cls(
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment