"examples/python/vscode:/vscode.git/clone" did not exist on "8588e33a464d9f82d6ad93a433590a3bc3ff92de"
Unverified Commit ab54fed2 authored by Tongping Liu's avatar Tongping Liu Committed by GitHub
Browse files

[hotfix] add kwargs for colo_addmm (#2171)

parent a110933d
...@@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor, ...@@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor,
mat2: ColoTensor, mat2: ColoTensor,
beta: Number = 1, beta: Number = 1,
alpha: Number = 1, alpha: Number = 1,
*args) -> ColoTensor: **kargs) -> ColoTensor:
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear. This method computes a linear.
""" """
...@@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor, ...@@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor,
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor( ret_tensor = ColoTensor.from_torch_tensor(
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha), tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs),
spec=ColoTensorSpec(mat2.get_process_group())) spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate(): if mat2.is_shard_1drow() and input_tensor.is_replicate():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment