Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
83289173
Commit
83289173
authored
Sep 08, 2022
by
Super Daniel
Committed by
Frank Lee
Sep 08, 2022
Browse files
[NFC] polish colossalai/testing/comparison.py code style. (#1558)
parent
7cc052f6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
1 deletion
+5
-1
colossalai/testing/comparison.py
colossalai/testing/comparison.py
+5
-1
No files found.
colossalai/testing/comparison.py
View file @
83289173
...
@@ -7,15 +7,19 @@ from torch.distributed import ProcessGroup
...
@@ -7,15 +7,19 @@ from torch.distributed import ProcessGroup
def
assert_equal
(
a
:
Tensor
,
b
:
Tensor
):
def
assert_equal
(
a
:
Tensor
,
b
:
Tensor
):
assert
torch
.
all
(
a
==
b
),
f
'expected a and b to be equal but they are not,
{
a
}
vs
{
b
}
'
assert
torch
.
all
(
a
==
b
),
f
'expected a and b to be equal but they are not,
{
a
}
vs
{
b
}
'
def
assert_not_equal
(
a
:
Tensor
,
b
:
Tensor
):
def
assert_not_equal
(
a
:
Tensor
,
b
:
Tensor
):
assert
not
torch
.
all
(
a
==
b
),
f
'expected a and b to be not equal but they are,
{
a
}
vs
{
b
}
'
assert
not
torch
.
all
(
a
==
b
),
f
'expected a and b to be not equal but they are,
{
a
}
vs
{
b
}
'
def
assert_close
(
a
:
Tensor
,
b
:
Tensor
,
rtol
:
float
=
1e-5
,
atol
:
float
=
1e-8
):
def
assert_close
(
a
:
Tensor
,
b
:
Tensor
,
rtol
:
float
=
1e-5
,
atol
:
float
=
1e-8
):
assert
torch
.
allclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
),
f
'expected a and b to be close but they are not,
{
a
}
vs
{
b
}
'
assert
torch
.
allclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
),
f
'expected a and b to be close but they are not,
{
a
}
vs
{
b
}
'
def
assert_close_loose
(
a
:
Tensor
,
b
:
Tensor
,
rtol
:
float
=
1e-3
,
atol
:
float
=
1e-3
):
def
assert_close_loose
(
a
:
Tensor
,
b
:
Tensor
,
rtol
:
float
=
1e-3
,
atol
:
float
=
1e-3
):
assert_close
(
a
,
b
,
rtol
,
atol
)
assert_close
(
a
,
b
,
rtol
,
atol
)
def
assert_equal_in_group
(
tensor
:
Tensor
,
process_group
:
ProcessGroup
=
None
):
def
assert_equal_in_group
(
tensor
:
Tensor
,
process_group
:
ProcessGroup
=
None
):
# all gather tensors from different ranks
# all gather tensors from different ranks
world_size
=
dist
.
get_world_size
(
process_group
)
world_size
=
dist
.
get_world_size
(
process_group
)
...
@@ -25,5 +29,5 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
...
@@ -25,5 +29,5 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
# check if they are equal one by one
# check if they are equal one by one
for
i
in
range
(
world_size
-
1
):
for
i
in
range
(
world_size
-
1
):
a
=
tensor_list
[
i
]
a
=
tensor_list
[
i
]
b
=
tensor_list
[
i
+
1
]
b
=
tensor_list
[
i
+
1
]
assert
torch
.
all
(
a
==
b
),
f
'expected tensors on rank
{
i
}
and
{
i
+
1
}
to be equal but they are not,
{
a
}
vs
{
b
}
'
assert
torch
.
all
(
a
==
b
),
f
'expected tensors on rank
{
i
}
and
{
i
+
1
}
to be equal but they are not,
{
a
}
vs
{
b
}
'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment