Unverified Commit ab54fed2 authored by Tongping Liu's avatar Tongping Liu Committed by GitHub
Browse files

[hotfix] add kwargs for colo_addmm (#2171)

parent a110933d
......@@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor,
mat2: ColoTensor,
beta: Number = 1,
alpha: Number = 1,
*args) -> ColoTensor:
**kargs) -> ColoTensor:
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
......@@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor,
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs),
spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment