Unverified commit 10b3df65, authored by Jiarui Fang, committed by GitHub
Browse files

[FAW] move coloparam setting in test code. (#1429)

parent cb98cf55
...@@ -67,9 +67,6 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag): ...@@ -67,9 +67,6 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
self.init_parameters() self.init_parameters()
else: else:
assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter" assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
_weight.process_group = ProcessGroup(tp_degree=self.world_size)
_weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
ComputeSpec(ComputePattern.TP1D))
self._weight = _weight self._weight = _weight
@property @property
......
...@@ -8,11 +8,9 @@ import random ...@@ -8,11 +8,9 @@ import random
import colossalai import colossalai
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
NUM_EMBED, EMBED_DIM = 10, 8 NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8 BATCH_SIZE = 8
...@@ -161,6 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size): ...@@ -161,6 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size):
weight = torch.rand(num_embed, embed_dim) weight = torch.rand(num_embed, embed_dim)
coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False) coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
# initialize the tensor spec for the embedding weight parameter,
# which is a ColoParameter.
coloweight.process_group = ProcessGroup(tp_degree=world_size)
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight, model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
include_last_offset=True, include_last_offset=True,
freeze=False, freeze=False,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment