OpenDAS / ColossalAI · Commits

Commit 9c2feb2f (Unverified), authored Sep 12, 2023 by digger yu; committed by GitHub on Sep 12, 2023
fix some typo with colossalai/device colossalai/tensor/ etc. (#4171)
Co-authored-by: flybird11111 <1829166702@qq.com>

Parent: d8ceeac1
Showing 6 changed files with 12 additions and 12 deletions (+12 -12)
colossalai/device/device_mesh.py (+6 -6)
colossalai/tensor/d_tensor/comm_spec.py (+1 -1)
colossalai/tensor/shape_consistency.py (+2 -2)
tests/kit/model_zoo/transformers/t5.py (+1 -1)
tests/test_booster/test_plugin/test_torch_ddp_plugin.py (+1 -1)
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py (+1 -1)
colossalai/device/device_mesh.py (view file @ 9c2feb2f)

```diff
@@ -59,7 +59,7 @@ class DeviceMesh:
         # 2. directly supply the logical mesh id
         assert mesh_shape is None or logical_mesh_id is None, \
             "Only one of mesh_shape and logical_mesh_id can be specified." \
-            "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
+            "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"

         if logical_mesh_id is None:
             self._mesh_shape = mesh_shape
```
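For context on the hunk above: the assert allows the logical mesh to be specified in exactly one of two ways. A minimal runnable sketch of that rule, using a hypothetical `make_logical_mesh` helper rather than `DeviceMesh`'s actual constructor:

```python
import torch

# Hypothetical helper illustrating the rule the assert enforces:
# the logical mesh must come from exactly one source.
def make_logical_mesh(physical_mesh_id, mesh_shape=None, logical_mesh_id=None):
    assert mesh_shape is None or logical_mesh_id is None, \
        "Only one of mesh_shape and logical_mesh_id can be specified."
    if logical_mesh_id is None:
        # derive the logical mesh by reshaping the flat physical IDs
        return physical_mesh_id.reshape(mesh_shape)
    return logical_mesh_id

mesh = make_logical_mesh(torch.arange(0, 16), mesh_shape=(4, 4))
print(mesh.shape)  # torch.Size([4, 4])
```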
```diff
@@ -74,7 +74,7 @@ class DeviceMesh:
         assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
             "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
         assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
-            "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
+            "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
         assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
             "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
```
```diff
@@ -118,7 +118,7 @@ class DeviceMesh:
         self._global_rank_of_current_process = None
         self._is_initialized = False

-        # attribute used to inidicate whether this object
+        # attribute used to indicate whether this object
         # is created using DeviceMesh.from_process_group
         # this attribute can be used to do some check in methods
         # such get_process_group as no global rank information
```
```diff
@@ -395,7 +395,7 @@ class DeviceMesh:
         Example:
             ```python
-            sphysical_mesh_id = torch.arange(0, 16)
+            physical_mesh_id = torch.arange(0, 16)
             mesh_shape = (4, 4)
             # logical mesh will look like
```
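The docstring example is cut off at "# logical mesh will look like"; assuming the usual reshape-based layout, the 4 x 4 logical mesh it refers to is presumably:

```python
import torch

# 16 flat physical ranks reshaped into a 4 x 4 logical mesh.
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
print(physical_mesh_id.reshape(mesh_shape))
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
```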
```diff
@@ -438,7 +438,7 @@ class DeviceMesh:
         # the _local_rank refers to the local rank of the current process
         for _local_rank in range(self.logical_mesh_id.shape[dim]):
-            # if this dimension is not initailized yet,
+            # if this dimension is not initialized yet,
             # initialize it with an empty array
             if dim not in processes_in_the_same_process_group:
                 processes_in_the_same_process_group[dim] = []
```
```diff
@@ -447,7 +447,7 @@ class DeviceMesh:
             process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()
             # replace the local rank in the given dimension with the
-            # lcoal rank of the current process iterated
+            # local rank of the current process iterated
             process_coordinates[dim] = _local_rank
             processes_in_the_same_process_group[dim].append(process_coordinates)
```
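These two hunks sit in the loop that builds process groups per mesh dimension. A standalone sketch of that grouping logic, with a hypothetical `coordinates_in_same_group` helper in place of the class method:

```python
import torch

# For a given mesh dimension, the processes in the same group are those
# whose coordinates differ only along that dimension.
def coordinates_in_same_group(logical_mesh, my_coords, dim):
    coords_list = []
    for _local_rank in range(logical_mesh.shape[dim]):
        coords = list(my_coords)
        coords[dim] = _local_rank  # vary only the chosen dimension
        coords_list.append(tuple(coords))
    return coords_list

mesh = torch.arange(16).reshape(4, 4)
print(coordinates_in_same_group(mesh, (2, 1), dim=0))
# [(0, 1), (1, 1), (2, 1), (3, 1)] -> the mesh column containing (2, 1)
```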
colossalai/tensor/d_tensor/comm_spec.py (view file @ 9c2feb2f)

```diff
@@ -28,7 +28,7 @@ class CommSpec:
     to determine the buffer shape, and logical_process_axis

     Argument:
-        comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
+        comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
         process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
         gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
         shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
```
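For the `gather_dim`/`shard_dim` arguments documented above, a rough single-process illustration using plain torch ops (the real CommSpec applies collectives across ranks):

```python
import torch

# Sharding splits a tensor along shard_dim; gathering concatenates the
# shards back along gather_dim.
full = torch.arange(8).reshape(2, 4)

shards = torch.chunk(full, chunks=2, dim=1)  # shard_dim = 1
regathered = torch.cat(shards, dim=1)        # gather_dim = 1
assert torch.equal(regathered, full)
```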
colossalai/tensor/shape_consistency.py (view file @ 9c2feb2f)

```diff
@@ -339,7 +339,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         RS01 -> RR
         '''
         valid_spec_dict = {}
-        comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
+        comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
         tensor_dims = len(source_spec.entire_shape)
         for f_index in range(tensor_dims - 1):
             for b_index in range(f_index + 1, tensor_dims):
```
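The nested `f_index`/`b_index` loops enumerate ordered pairs of tensor dimensions for the mix-gather; for a 3-D tensor, for example:

```python
# Each pair satisfies f_index < b_index, matching the loop bounds above.
tensor_dims = 3
pairs = [(f, b) for f in range(tensor_dims - 1) for b in range(f + 1, tensor_dims)]
print(pairs)  # [(0, 1), (0, 2), (1, 2)]
```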
```diff
@@ -362,7 +362,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 b_target_pair = (b_index, [])
             gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
-            comm_spec = CommSpec(comm_pathern,
+            comm_spec = CommSpec(comm_pattern,
                                  sharding_spec=source_spec,
                                  gather_dim=gather_dim,
                                  logical_process_axis=logical_process_axes,
```
tests/kit/model_zoo/transformers/t5.py (view file @ 9c2feb2f)

```diff
@@ -43,7 +43,7 @@ def data_gen_for_t5_model():
 # output transform function
 output_transform_fn = lambda x: x

-# define loss funciton
+# define loss function
 loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
 loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
 loss_fn_for_conditional_generation = lambda x: x.loss
```
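The corrected comment heads a block of lambda loss functions; they differ because encoder-style outputs expose `last_hidden_state` while conditional-generation outputs already carry a loss. A standalone illustration with stand-in output objects (real runs use HuggingFace model outputs):

```python
import torch
from types import SimpleNamespace

loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss

# Stand-in objects mimicking the two output shapes.
encoder_out = SimpleNamespace(last_hidden_state=torch.randn(2, 4, 8))
generation_out = SimpleNamespace(loss=torch.tensor(0.7))
print(loss_fn_for_t5_model(encoder_out))                   # scalar tensor
print(loss_fn_for_conditional_generation(generation_out))  # tensor(0.7000)
```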
tests/test_booster/test_plugin/test_torch_ddp_plugin.py (view file @ 9c2feb2f)

```diff
@@ -64,7 +64,7 @@ def check_torch_ddp_no_sync():
     model = DummyModel()
     criterion = lambda x: x.mean()
     optimizer = SGD(model.parameters(), lr=1e-3)
-    # create a custom dasetset with 0 to 10
+    # create a custom dataset with 0 to 10
     dataset = torch.arange(0, 10)
     train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
     model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
```
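The fixed comment describes the toy dataset, a plain tensor of 0..9. A sketch of how such a tensor batches, using a vanilla DataLoader; the plugin's `prepare_dataloader` presumably also attaches a distributed sampler, which this standalone example omits:

```python
import torch
from torch.utils.data import DataLoader

# A 1-D tensor works directly as a map-style dataset.
dataset = torch.arange(0, 10)
loader = DataLoader(dataset, batch_size=2)
for batch in loader:
    print(batch)  # tensor([0, 1]), tensor([2, 3]), ...
```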
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py (view file @ 9c2feb2f)

```diff
@@ -15,7 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo

-# test baisc fsdp function
+# test basic fsdp function
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)
```