Unverified Commit 039b7ed3 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[polish] add update directory in gemini; rename AgChunk to ChunkV2 (#1432)

parent f20cb4e8
from .chunkv2 import ChunkV2
...@@ -8,7 +8,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF ...@@ -8,7 +8,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
free_storage, alloc_storage free_storage, alloc_storage
class AgChunk: class ChunkV2:
def __init__(self, def __init__(self,
chunk_size: int, chunk_size: int,
process_group: ColoProcessGroup, process_group: ColoProcessGroup,
......
...@@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device ...@@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ColoParameter from colossalai.tensor import ColoParameter
from colossalai.gemini import TensorState from colossalai.gemini import TensorState
from colossalai.gemini.ag_chunk import AgChunk from colossalai.gemini.update import ChunkV2
def dist_sum(x): def dist_sum(x):
...@@ -38,7 +38,7 @@ def check_euqal(param, param_cp): ...@@ -38,7 +38,7 @@ def check_euqal(param, param_cp):
def exam_chunk_basic(init_device, keep_gathered, pin_memory): def exam_chunk_basic(init_device, keep_gathered, pin_memory):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = ColoProcessGroup() pg = ColoProcessGroup()
my_chunk = AgChunk( my_chunk = ChunkV2(
chunk_size=1024, chunk_size=1024,
process_group=pg, process_group=pg,
dtype=torch.float32, dtype=torch.float32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment