OpenDAS / ColossalAI · Commits

Commit 10b3df65 (Unverified)
[FAW] move coloparam setting in test code. (#1429)
Authored by Jiarui Fang on Aug 10, 2022; committed by GitHub on Aug 10, 2022.
Parent: cb98cf55
Showing 2 changed files with 6 additions and 6 deletions.
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py  +0 -3
tests/test_tensor/ops/test_cache_embedding.py  +6 -3
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
@@ -67,9 +67,6 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
             self.init_parameters()
         else:
             assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
-            _weight.process_group = ProcessGroup(tp_degree=self.world_size)
-            _weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
-                                    ComputeSpec(ComputePattern.TP1D))
             self._weight = _weight

     @property
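With these three lines removed, the constructor only asserts that a user-supplied _weight is a ColoParameter; attaching the ProcessGroup and the TP1D shard spec is now the caller's responsibility, as the updated test below shows (a consolidated sketch of that caller-side setup follows the test diff).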
tests/test_tensor/ops/test_cache_embedding.py
@@ -8,11 +8,9 @@ import random

 import colossalai
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import ColoParameter
+from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
 from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
-from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
-

 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
@@ -161,6 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size):
     weight = torch.rand(num_embed, embed_dim)
     coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
+    # initialize the tensor spec for the embedding weight parameter,
+    # which is an ColoParameter.
+    coloweight.process_group = ProcessGroup(tp_degree=world_size)
+    coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]),
+                               ComputeSpec(ComputePattern.TP1D))

     model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                           include_last_offset=True,
                                                           freeze=False,
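Read together, the two hunks relocate the tensor-spec setup from the library into user code. Below is a minimal, self-contained sketch (not part of the commit) of that caller-side sequence, assembled only from calls visible in the diffs above. It assumes a ColossalAI build matching this commit, that colossalai.launch has already initialized the distributed context (as the test's multiprocessing harness does before run_parallel_freq_aware_embed runs), and that world_size is the tensor-parallel degree; the helper name build_parallel_embedding is hypothetical, and the test's real from_pretrained call passes further arguments that are truncated in the hunk.

    import torch

    from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
    from colossalai.nn._ops.cache_embedding import ParallelFreqAwareEmbeddingBag


    def build_parallel_embedding(num_embed: int, embed_dim: int, world_size: int):
        # Wrap a plain weight tensor as a ColoParameter; after this commit the
        # constructor merely asserts the type instead of configuring the spec.
        weight = torch.rand(num_embed, embed_dim)
        coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)

        # Caller-side setup moved here by this commit: attach the process group,
        # then shard the last weight dimension across tensor-parallel ranks (TP1D).
        coloweight.process_group = ProcessGroup(tp_degree=world_size)
        coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]),
                                   ComputeSpec(ComputePattern.TP1D))

        # Hand the fully specified parameter to the embedding bag, as the test does
        # (the test passes additional keyword arguments not shown in the hunk).
        return ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                             include_last_offset=True,
                                                             freeze=False)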