Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d7f8db8e
Unverified
Commit
d7f8db8e
authored
Jan 22, 2024
by
Hongxin Liu
Committed by
GitHub
Jan 22, 2024
Browse files
[hotfix] fix 3d plugin test (#5292)
parent
d66e6988
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
2 deletions
+5
-2
tests/test_booster/test_plugin/test_3d_plugin.py
tests/test_booster/test_plugin/test_3d_plugin.py
+5
-2
No files found.
tests/test_booster/test_plugin/test_3d_plugin.py
View file @
d7f8db8e
...
@@ -8,13 +8,14 @@ from torch.testing import assert_close
...
@@ -8,13 +8,14 @@ from torch.testing import assert_close
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
import
colossalai
import
colossalai
from
colossalai.accelerator
import
get_accelerator
from
colossalai.booster
import
Booster
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
HybridParallelPlugin
from
colossalai.booster.plugin
import
HybridParallelPlugin
from
colossalai.fx
import
is_compatible_with_meta
from
colossalai.fx
import
is_compatible_with_meta
from
colossalai.lazy.lazy_init
import
LazyInitContext
from
colossalai.lazy.lazy_init
import
LazyInitContext
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.testing
import
clear_cache_before_run
,
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
clear_cache_before_run
,
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils
import
get_current_device
,
set_seed
from
colossalai.utils
import
set_seed
from
tests.kit.model_zoo
import
model_zoo
from
tests.kit.model_zoo
import
model_zoo
...
@@ -23,7 +24,9 @@ class RandomDataset(Dataset):
...
@@ -23,7 +24,9 @@ class RandomDataset(Dataset):
self
.
num_samples
=
num_samples
self
.
num_samples
=
num_samples
self
.
max_length
=
max_length
self
.
max_length
=
max_length
set_seed
(
42
)
set_seed
(
42
)
self
.
input_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
num_samples
,
max_length
),
device
=
get_current_device
())
self
.
input_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
num_samples
,
max_length
),
device
=
get_accelerator
().
get_current_device
()
)
self
.
attention_mask
=
torch
.
ones_like
(
self
.
input_ids
)
self
.
attention_mask
=
torch
.
ones_like
(
self
.
input_ids
)
def
__len__
(
self
):
def
__len__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment