Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9c2feb2f
Unverified
Commit
9c2feb2f
authored
Sep 12, 2023
by
digger yu
Committed by
GitHub
Sep 12, 2023
Browse files
fix some typo with colossalai/device colossalai/tensor/ etc. (#4171)
Co-authored-by:
flybird11111
<
1829166702@qq.com
>
parent
d8ceeac1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
12 additions
and
12 deletions
+12
-12
colossalai/device/device_mesh.py
colossalai/device/device_mesh.py
+6
-6
colossalai/tensor/d_tensor/comm_spec.py
colossalai/tensor/d_tensor/comm_spec.py
+1
-1
colossalai/tensor/shape_consistency.py
colossalai/tensor/shape_consistency.py
+2
-2
tests/kit/model_zoo/transformers/t5.py
tests/kit/model_zoo/transformers/t5.py
+1
-1
tests/test_booster/test_plugin/test_torch_ddp_plugin.py
tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+1
-1
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
+1
-1
No files found.
colossalai/device/device_mesh.py
View file @
9c2feb2f
...
...
@@ -59,7 +59,7 @@ class DeviceMesh:
# 2. directly supply the logical mesh id
assert
mesh_shape
is
None
or
logical_mesh_id
is
None
,
\
"Only one of mesh_shape and logical_mesh_id can be specified."
\
"Logical mesh IDs are obtained from either mesh_shape + phy
i
scal_mesh_id or directly from the user-supplied logical_mesh_id"
"Logical mesh IDs are obtained from either mesh_shape + phys
i
cal_mesh_id or directly from the user-supplied logical_mesh_id"
if
logical_mesh_id
is
None
:
self
.
_mesh_shape
=
mesh_shape
...
...
@@ -74,7 +74,7 @@ class DeviceMesh:
assert
torch
.
equal
(
torch
.
unique
(
self
.
_physical_mesh_id
),
torch
.
unique
(
self
.
logical_mesh_id
)),
\
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert
torch
.
unique
(
self
.
_physical_mesh_id
).
numel
()
==
self
.
_physical_mesh_id
.
numel
(),
\
"Found duplicate IDs in the phy
i
scal_mesh_id and this is not allowed, please check your physical_mesh_id again."
"Found duplicate IDs in the phys
i
cal_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert
torch
.
unique
(
self
.
logical_mesh_id
).
numel
()
==
self
.
logical_mesh_id
.
numel
(),
\
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
...
...
@@ -118,7 +118,7 @@ class DeviceMesh:
self
.
_global_rank_of_current_process
=
None
self
.
_is_initialized
=
False
# attribute used to in
i
dicate whether this object
d
# attribute used to indicate whether this object
# is created using DeviceMesh.from_process_group
# this attribute can be used to do some check in methods
# such get_process_group as no global rank information
...
...
@@ -395,7 +395,7 @@ class DeviceMesh:
Example:
```python
s
physical_mesh_id = torch.arange(0, 16)
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# logical mesh will look like
...
...
@@ -438,7 +438,7 @@ class DeviceMesh:
# the _local_rank refers to the local rank of the current process
for
_local_rank
in
range
(
self
.
logical_mesh_id
.
shape
[
dim
]):
# if this dimension is not init
a
ilized yet,
# if this dimension is not initi
a
lized yet,
# initialize it with an empty array
if
dim
not
in
processes_in_the_same_process_group
:
processes_in_the_same_process_group
[
dim
]
=
[]
...
...
@@ -447,7 +447,7 @@ class DeviceMesh:
process_coordinates
=
self
.
_global_to_local_rank_mapping
[
global_rank
].
copy
()
# replace the local rank in the given dimension with the
# l
c
oal rank of the current process iterated
# lo
c
al rank of the current process iterated
process_coordinates
[
dim
]
=
_local_rank
processes_in_the_same_process_group
[
dim
].
append
(
process_coordinates
)
...
...
colossalai/tensor/d_tensor/comm_spec.py
View file @
9c2feb2f
...
...
@@ -28,7 +28,7 @@ class CommSpec:
to determine the buffer shape, and logical_process_axis
Argument:
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
comm_pattern(CollectiveCommPattern): de
s
cribe the communication method used in this spec.
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
...
...
colossalai/tensor/shape_consistency.py
View file @
9c2feb2f
...
...
@@ -339,7 +339,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
RS01 -> RR
'''
valid_spec_dict
=
{}
comm_pat
h
ern
=
CollectiveCommPattern
.
MIXGATHER_FWD_SPLIT_BWD
comm_pat
t
ern
=
CollectiveCommPattern
.
MIXGATHER_FWD_SPLIT_BWD
tensor_dims
=
len
(
source_spec
.
entire_shape
)
for
f_index
in
range
(
tensor_dims
-
1
):
for
b_index
in
range
(
f_index
+
1
,
tensor_dims
):
...
...
@@ -362,7 +362,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
b_target_pair
=
(
b_index
,
[])
gather_dim
,
logical_process_axes
=
mix_gather_simulator
(
f_target_pair
,
b_target_pair
)
comm_spec
=
CommSpec
(
comm_pat
h
ern
,
comm_spec
=
CommSpec
(
comm_pat
t
ern
,
sharding_spec
=
source_spec
,
gather_dim
=
gather_dim
,
logical_process_axis
=
logical_process_axes
,
...
...
tests/kit/model_zoo/transformers/t5.py
View file @
9c2feb2f
...
...
@@ -43,7 +43,7 @@ def data_gen_for_t5_model():
# output transform function
output_transform_fn
=
lambda
x
:
x
# define loss func
i
ton
# define loss funct
i
on
loss_fn_for_t5_model
=
lambda
x
:
x
.
last_hidden_state
.
mean
()
loss_fn_for_encoder_only
=
lambda
x
:
x
.
last_hidden_state
.
mean
()
loss_fn_for_conditional_generation
=
lambda
x
:
x
.
loss
...
...
tests/test_booster/test_plugin/test_torch_ddp_plugin.py
View file @
9c2feb2f
...
...
@@ -64,7 +64,7 @@ def check_torch_ddp_no_sync():
model
=
DummyModel
()
criterion
=
lambda
x
:
x
.
mean
()
optimizer
=
SGD
(
model
.
parameters
(),
lr
=
1e-3
)
# create a custom da
se
tset with 0 to 10
# create a custom dat
a
set with 0 to 10
dataset
=
torch
.
arange
(
0
,
10
)
train_dataloader
=
plugin
.
prepare_dataloader
(
dataset
,
batch_size
=
2
)
model
,
optimizer
,
criterion
,
train_dataloader
,
_
=
booster
.
boost
(
model
,
...
...
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
View file @
9c2feb2f
...
...
@@ -15,7 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
from
tests.kit.model_zoo
import
model_zoo
# test ba
i
sc fsdp function
# test bas
i
c fsdp function
def
run_fn
(
model_fn
,
data_gen_fn
,
output_transform_fn
):
plugin
=
TorchFSDPPlugin
()
booster
=
Booster
(
plugin
=
plugin
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment