Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c9427a32
Unverified
Commit
c9427a32
authored
Aug 11, 2022
by
Jiarui Fang
Committed by
GitHub
Aug 11, 2022
Browse files
hotfix #1434 (#1437)
parent
039b7ed3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
.../nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
+5
-3
No files found.
colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
View file @
c9427a32
...
@@ -7,7 +7,7 @@ from .cache_mgr import CachedParamMgr
...
@@ -7,7 +7,7 @@ from .cache_mgr import CachedParamMgr
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
.._utils
import
dual_all_to_all
from
.._utils
import
dual_all_to_all
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputeSpec
,
ComputePattern
,
ProcessGroup
from
colossalai.tensor
import
ColoParameter
,
ShardSpec
,
ComputeSpec
,
ComputePattern
,
ProcessGroup
,
ColoTensorSpec
def
get_partition
(
embedding_dim
,
rank
,
world_size
)
->
Tuple
[
int
,
int
,
bool
]:
def
get_partition
(
embedding_dim
,
rank
,
world_size
)
->
Tuple
[
int
,
int
,
bool
]:
...
@@ -57,13 +57,15 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
...
@@ -57,13 +57,15 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
self
.
embedding_dim_per_partition
=
self
.
partition_end_index
-
self
.
partition_start_index
self
.
embedding_dim_per_partition
=
self
.
partition_end_index
-
self
.
partition_start_index
if
_weight
is
None
:
if
_weight
is
None
:
self
.
_weight
.
process_group
=
ProcessGroup
(
tp_degree
=
self
.
world_size
)
colo_tensor_spec
=
ColoTensorSpec
(
pg
=
ProcessGroup
(
tp_degree
=
self
.
world_size
),
dist_attr
=
ShardSpec
(
dims
=
[
-
1
],
num_partitions
=
[
self
.
world_size
]),
compute_attr
=
ComputePattern
.
TP1D
)
self
.
_weight
=
ColoParameter
.
from_torch_tensor
(
torch
.
empty
(
self
.
num_embeddings
,
self
.
_weight
=
ColoParameter
.
from_torch_tensor
(
torch
.
empty
(
self
.
num_embeddings
,
self
.
embedding_dim_per_partition
,
self
.
embedding_dim_per_partition
,
device
=
'cpu'
,
device
=
'cpu'
,
dtype
=
dtype
),
dtype
=
dtype
),
requires_grad
=
True
,
requires_grad
=
True
,
spec
=
ShardSpec
(
dims
=
[
-
1
],
num_partitions
=
[
self
.
world_size
])
)
spec
=
colo_tensor_spec
)
self
.
init_parameters
()
self
.
init_parameters
()
else
:
else
:
assert
isinstance
(
_weight
,
ColoParameter
),
"initialized weight must in type of ColoParameter"
assert
isinstance
(
_weight
,
ColoParameter
),
"initialized weight must in type of ColoParameter"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment