Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
196d4696
Commit
196d4696
authored
Mar 28, 2023
by
Tong Li
Committed by
binmakeswell
Mar 29, 2023
Browse files
[NFC] polish colossalai/nn/_ops/addmm.py code style (#3274)
parent
4b954649
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
7 deletions
+11
-7
colossalai/nn/_ops/addmm.py
colossalai/nn/_ops/addmm.py
+11
-7
No files found.
colossalai/nn/_ops/addmm.py
View file @
196d4696
import
torch
import
torch
from
colossalai.tensor
import
ColoTensor
,
ColoTensorSpec
,
ComputePattern
,
ComputeSpec
,
ReplicaSpec
,
ShardSpec
,
distspec
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor
import
ComputePattern
,
ComputePattern
,
ComputeSpec
,
ColoTensor
from
colossalai.tensor
import
distspec
,
ColoTensorSpec
,
ShardSpec
,
ReplicaSpec
from
._utils
import
GeneralTensor
,
Number
,
convert_to_colo_tensor
,
reduce_grad
,
reduce_input
from
._utils
import
GeneralTensor
,
Number
,
convert_to_colo_tensor
from
._utils
import
reduce_input
,
reduce_grad
def
colo_addmm_1Drow
(
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Number
,
def
colo_addmm_1Drow
(
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Number
,
...
@@ -69,9 +69,13 @@ def colo_addmm(input_tensor: GeneralTensor,
...
@@ -69,9 +69,13 @@ def colo_addmm(input_tensor: GeneralTensor,
if
not
mat2
.
has_compute_spec
():
# No Model Parallel Applied
if
not
mat2
.
has_compute_spec
():
# No Model Parallel Applied
assert
mat2
.
is_replicate
(),
'Invalid mat2 spec for native addmm op'
assert
mat2
.
is_replicate
(),
'Invalid mat2 spec for native addmm op'
assert
input_tensor
.
is_replicate
(),
'Invalid input spec for native addmm op'
assert
input_tensor
.
is_replicate
(),
'Invalid input spec for native addmm op'
ret_tensor
=
ColoTensor
.
from_torch_tensor
(
ret_tensor
=
ColoTensor
.
from_torch_tensor
(
tensor
=
torch
.
addmm
(
input_tensor
,
tensor
=
torch
.
addmm
(
input_tensor
,
mat1
,
mat2
,
beta
=
beta
,
alpha
=
alpha
,
**
kargs
),
mat1
,
spec
=
ColoTensorSpec
(
mat2
.
get_process_group
()))
mat2
,
beta
=
beta
,
alpha
=
alpha
,
**
kargs
),
spec
=
ColoTensorSpec
(
mat2
.
get_process_group
()))
elif
mat2
.
has_compute_pattern
(
ComputePattern
.
TP1D
):
# Single Model Parallel Applied
elif
mat2
.
has_compute_pattern
(
ComputePattern
.
TP1D
):
# Single Model Parallel Applied
if
mat2
.
is_shard_1drow
()
and
input_tensor
.
is_replicate
():
if
mat2
.
is_shard_1drow
()
and
input_tensor
.
is_replicate
():
mode
=
'row'
mode
=
'row'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment