OpenDAS / ColossalAI · Commits · 0dd4e2bb

Commit 0dd4e2bb (unverified), authored Jun 27, 2022 by Jiarui Fang, committed via GitHub on Jun 27, 2022
Parent: dd042090

[Tensor] rename some APIs in TensorSpec and Polish view unittest (#1176)

Showing 9 changed files with 26 additions and 18 deletions (+26, -18):
colossalai/nn/_ops/addmm.py          +3  -3
colossalai/nn/_ops/embedding.py      +3  -2
colossalai/nn/_ops/embedding_bag.py  +1  -1
colossalai/nn/_ops/linear.py         +3  -3
colossalai/nn/_ops/loss.py           +1  -1
colossalai/tensor/chunk.py           +3  -1
colossalai/tensor/distspec.py        +1  -1
colossalai/tensor/tensor_spec.py     +2  -2
tests/test_tensor/test_tensor.py     +9  -4
colossalai/nn/_ops/addmm.py

@@ -72,10 +72,10 @@ def colo_addmm(input_tensor: GeneralTensor,
         assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_replicate():
+        if mat2.tensor_spec.is_shard_1drow() and input_tensor.tensor_spec.is_replicate():
             mode = 'row'
-        elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col() or input_tensor.tensor_spec.is_1D_row()):
+        elif mat2.tensor_spec.is_shard_1dcol() and (input_tensor.tensor_spec.is_shard_1dcol() or input_tensor.tensor_spec.is_shard_1drow()):
             mode = 'col'
         else:
             raise NotImplementedError
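The rename lets the dispatch above read directly off the shard layout: is_shard_1drow() means mat2 is split along dim 0, so each rank computes a partial product that must be all-reduced, while is_shard_1dcol() splits dim -1 and the output comes out column-sharded. A minimal sketch of the 'col' case in plain PyTorch (no ColossalAI needed; the shapes and the two-way split are illustrative only):

    import torch

    torch.manual_seed(0)
    inp = torch.randn(4, 6)      # the addend, column-sliced to match the output
    mat1 = torch.randn(4, 8)     # replicated activations
    mat2 = torch.randn(8, 6)     # weight, "col-sharded" across two fake ranks

    full = torch.addmm(inp, mat1, mat2)

    # Each fake rank holds a column slice of mat2 (and of the addend) and
    # produces the matching column slice of the result.
    shards = [torch.addmm(inp[:, :3], mat1, mat2[:, :3]),
              torch.addmm(inp[:, 3:], mat1, mat2[:, 3:])]
    assert torch.allclose(full, torch.cat(shards, dim=-1))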
colossalai/nn/_ops/embedding.py

@@ -32,6 +32,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
+    compute_spec = weight.tensor_spec.compute_spec
     if compute_spec.output_replicate:
         return output.to_replicate()
     else:

@@ -125,9 +126,9 @@ def colo_embedding(input_tensor: GeneralTensor,
                        scale_grad_by_freq=scale_grad_by_freq,
                        sparse=sparse))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_row():
+        if weight.tensor_spec.is_shard_1drow():
             mode = 'row'
-        elif weight.tensor_spec.is_1D_col():
+        elif weight.tensor_spec.is_shard_1dcol():
             mode = 'col'
         else:
             raise NotImplementedError
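Besides the rename, the first hunk now pulls compute_spec out of the weight's tensor_spec immediately before it is consulted. The 'col' path works because column-sharding an embedding weight splits the embedding dimension, and per-shard lookups can simply be concatenated, which is what colo_embedding_1Dcol relies on before optionally replicating the output. A quick plain-torch check (fake two-rank split; all names are local to the example):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    ids = torch.tensor([[1, 3], [0, 2]])
    weight = torch.randn(5, 8)                 # vocab 5, embedding dim 8

    full = F.embedding(ids, weight)
    parts = [F.embedding(ids, weight[:, :4]),  # rank 0's columns
             F.embedding(ids, weight[:, 4:])]  # rank 1's columns
    assert torch.equal(full, torch.cat(parts, dim=-1))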
colossalai/nn/_ops/embedding_bag.py

@@ -104,7 +104,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
                            include_last_offset=include_last_offset,
                            padding_idx=padding_idx))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col():
+        if weight.tensor_spec.is_shard_1dcol():
             tp_mode = 'col'
         else:
             raise NotImplementedError
colossalai/nn/_ops/linear.py

@@ -71,10 +71,10 @@ def colo_linear_imp(input_tensor: GeneralTensor,
         assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_replicate()):
+        if weight.tensor_spec.is_shard_1dcol() and (bias is None or bias.tensor_spec.is_replicate()):
             mode = 'row'
-        elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row() or bias.tensor_spec.is_1D_col()):
+        elif weight.tensor_spec.is_shard_1drow() and (bias is None or bias.tensor_spec.is_shard_1drow() or bias.tensor_spec.is_shard_1dcol()):
             mode = 'col'
         else:
             raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight.tensor_spec}, bias {bias}")
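The mapping here looks inverted but is correct: F.linear computes x @ weight.T, so a weight sharded along its last dim (is_shard_1dcol) actually splits the input-feature dimension, i.e. a row-parallel linear whose partial outputs must be summed, hence mode = 'row'. A plain-torch sanity check of that identity (two fake ranks; a tolerance is used only because the summation order differs):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(4, 6)
    w = torch.randn(3, 6)                      # (out_features, in_features)

    full = F.linear(x, w)                      # x @ w.T
    # Weight "col shard" slices in_features: each fake rank produces a
    # partial sum (an all-reduce in real tensor parallelism).
    partial = F.linear(x[:, :3], w[:, :3]) + F.linear(x[:, 3:], w[:, 3:])
    assert torch.allclose(full, partial, atol=1e-5)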
colossalai/nn/_ops/loss.py

@@ -29,7 +29,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
-        if input_tensor.tensor_spec.is_1D_col():
+        if input_tensor.tensor_spec.is_shard_1dcol():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
         else:
colossalai/tensor/chunk.py

@@ -116,6 +116,7 @@ class Chunk:
         if self.is_src_rank:
             self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
+            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:
             tensor.storage().resize_(0)

@@ -131,6 +132,7 @@ class Chunk:
         self._update_tensors_state(TensorState.FREE)

     def _update_tensors_ptr(self) -> None:
+        assert type(self._payload) == torch.Tensor
         for tensor, tensor_info in self.tensors_info.items():
             tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

@@ -228,7 +230,7 @@ class Chunk:
             data_slice (torch.Tensor): the tensor to be copied to the chunk
         """
         tensor_info = self.tensors_info[tensor]
-        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
+        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
         tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

     @property
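The only behavioral change in this file is the last hunk: copy_ now reads from data_slice.flatten() instead of data_slice.view(-1). Tensor.view requires the new shape to be expressible over the existing memory layout and raises on, for example, transposed tensors, whereas flatten() silently falls back to a copy, so the chunk copy now also accepts non-contiguous slices. A quick standalone check:

    import torch

    t = torch.randn(3, 4).t()          # transposed -> non-contiguous
    try:
        t.view(-1)                     # raises RuntimeError on this layout
    except RuntimeError as e:
        print("view(-1) failed:", e)

    assert t.flatten().shape == (12,)  # flatten() copies when it has to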
colossalai/tensor/distspec.py

@@ -54,5 +54,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
     assert process_group is not None
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size()
+    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
colossalai/tensor/tensor_spec.py

@@ -32,11 +32,11 @@ class TensorSpec(object):
                 and self.dist_spec.num_partitions[0] == 1) \
             or (self.dist_spec.process_group.size() == 1)

-    def is_1D_col(self):
+    def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1

-    def is_1D_row(self):
+    def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
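The new names state exactly what is checked: a SHARD placement over a single dim, with dim -1 meaning column-sharded and dim 0 meaning row-sharded. The same logic over a toy spec (Placement and ToyDistSpec are simplified stand-ins for the real DistPlacementPattern/_DistSpec):

    from dataclasses import dataclass
    from enum import Enum
    from typing import Tuple

    class Placement(Enum):             # stand-in for DistPlacementPattern
        REPLICATE = 'r'
        SHARD = 's'

    @dataclass
    class ToyDistSpec:
        placement: Placement
        dims: Tuple[int, ...] = ()

    def is_shard_1dcol(s: ToyDistSpec) -> bool:
        return s.placement == Placement.SHARD and len(s.dims) == 1 and s.dims[0] == -1

    def is_shard_1drow(s: ToyDistSpec) -> bool:
        return s.placement == Placement.SHARD and len(s.dims) == 1 and s.dims[0] == 0

    assert is_shard_1dcol(ToyDistSpec(Placement.SHARD, (-1,)))
    assert is_shard_1drow(ToyDistSpec(Placement.SHARD, (0,)))
    assert not is_shard_1dcol(ToyDistSpec(Placement.REPLICATE))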
tests/test_tensor/test_tensor.py

@@ -63,13 +63,19 @@ def test_operand():
 def _run_view(world_size):
     t_ref = torch.randn(4, 5)
-    t = ColoTensor.from_torch_tensor(t_ref, TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[2])))
+    t = ColoTensor.from_torch_tensor(
+        t_ref, TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])))
     assert t.size()[0] == 4 * world_size
     assert t.size(1) == 5
     assert t.size() == torch.Size([4 * world_size, 5])

+    t.view_base(4 * 5)
+    assert t.tensor_spec.dist_spec.placement.value == 's'
+
     t = t.view(4 * 5 * world_size)
+    assert t.tensor_spec.dist_spec.placement.value == 'r'
     assert t.shape == torch.Size([4 * 5 * world_size])

@@ -100,11 +106,10 @@ def run_dist_tests(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
-def _test_dist_cases(world_size):
+def test_dist_cases(world_size):
     run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
     # _test_dist_init(4)
-    _test_dist_cases(2)
+    test_dist_cases(2)
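Two things happen in this file. The view test now hard-codes nothing about the world size (num_partitions=[world_size]) and pins down the placement transition: 's' (shard) before the reshape, 'r' (replicate) after, since a view that crosses the sharded dim forces a gather. And dropping the leading underscore from test_dist_cases lets pytest actually collect the function, so the dist cases run under the world_size parametrization. The gather-then-view behavior, mimicked in plain torch:

    import torch

    world_size = 2
    full = torch.randn(4 * world_size, 5)
    shards = list(torch.chunk(full, world_size, dim=0))  # placement 's' on dim 0

    gathered = torch.cat(shards, dim=0)                  # mimics to_replicate(): 's' -> 'r'
    assert torch.equal(gathered.view(4 * 5 * world_size), full.flatten())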