OpenDAS / ColossalAI / Commits / 11bddb6e

Commit 11bddb6e
[zero] update zero context init with the updated test utils (#327)

Authored Mar 08, 2022 by Jiarui Fang; committed by Frank Lee on Mar 11, 2022.
Parent: 6268446b
Showing 10 changed files with 97 additions and 50 deletions (+97 -50):
colossalai/engine/ophooks/_memtracer_ophook.py        +8   -8
colossalai/zero/init_ctx/init_context.py              +19  -10
colossalai/zero/shard_utils/base_shard_strategy.py    +5   -6
colossalai/zero/sharded_param/sharded_param.py        +14  -2
tests/components_to_test/nested_model.py              +13  -6
tests/components_to_test/repeated_computed_layer.py   +9   -3
tests/components_to_test/resnet.py                    +9   -3
tests/test_engine/test_engine.py                      +3   -2
tests/test_zero_data_parallel/test_init_context.py    +15  -10
tests/test_zero_data_parallel/test_shard_param.py     +2   -0
colossalai/engine/ophooks/_memtracer_ophook.py

-from re import S
 from colossalai.context.parallel_mode import ParallelMode
 import torch
 from . import BaseOpHook
@@ -7,7 +6,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from time import sleep, time
 import pickle
-from typing import Union, Optional
+from typing import Optional
 from colossalai.core import global_context as gpc
@@ -19,12 +18,13 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
     """
     ret: int = torch.cuda.memory_allocated(device)
     # get the peak memory to report correct data, so reset the counter for the next call
     if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
         torch.cuda.reset_peak_memory_stats(device)
     return ret


 class AsyncMemoryMonitor:
     def __init__(self, power=10):
         """
         An Async Mem Monitor runing during computing.
@@ -81,7 +81,7 @@ class AsyncMemoryMonitor:
     def save(self, filename):
         with open(filename, "wb") as f:
             pickle.dump(self.state_dict(), f)

     def clear(self):
         self.mem_stats.clear()
         self.time_stamps.clear()
@@ -92,7 +92,7 @@ class MemTracerOpHook(BaseOpHook):
     '''
     Collect GPU memory usage information
     Args:
         warmup (int): This parameter indicates how many iterations to truncate
             before profiling, e.g. set to 5 and the data will start from 6-th iteration
         refreshrate (int): This parameter decides the frequency of write file.
@@ -106,6 +106,7 @@ class MemTracerOpHook(BaseOpHook):
         _data_prefix (string): the prefix of the stats data file
         _rank (int): the rank of current node
     '''

     def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
         super().__init__()
         self.async_mem_monitor = AsyncMemoryMonitor()
@@ -128,7 +129,7 @@ class MemTracerOpHook(BaseOpHook):
     @property
     def refreshrate(self) -> int:
         return self._refreshrate

     @property
     def warmup(self) -> int:
         return self._warmup
@@ -178,8 +179,7 @@ class MemTracerOpHook(BaseOpHook):
         # every `refreshrate` times, refresh the file
         if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
             # output file info
             self._logger.info(
                 f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
             self.save_results()
             self._count += 1
             self._logger.debug(f'data file has been refreshed {self._count} times')
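For reference, the helper touched by the hunk above can be exercised on its own. The sketch below lightly adapts get_cuda_memory_used as shown in the diff (a default of None is added to the device argument for convenience) and assumes a CUDA-capable machine:

```python
from typing import Optional

import torch


def get_cuda_memory_used(device: Optional[torch.device] = None) -> int:
    """Return the bytes currently allocated on `device` (restated from the hunk above)."""
    ret: int = torch.cuda.memory_allocated(device)
    # reset the peak counter so the next query reports a fresh peak (pytorch 1.4+)
    if hasattr(torch.cuda, "reset_peak_memory_stats"):
        torch.cuda.reset_peak_memory_stats(device)
    return ret


if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    print(f"allocated: {get_cuda_memory_used(torch.device('cuda')) / 1024 ** 2:.1f} MB")
```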
colossalai/zero/init_ctx/init_context.py

@@ -82,25 +82,31 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.
     """

     def __init__(self,
                  convert_fp16: bool,
                  convert_cuda: bool,
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
-                 ):
+                 rm_torch_payload_on_the_fly=False):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.convert_cuda = convert_cuda
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
+        self.initialized_param_list = []

     def _post_context_exec(self):
         """The callback function when the context exits.
         """
-        pass
+        if not self.rm_torch_payload_on_the_fly:
+            for param in self.initialized_param_list:
+                assert hasattr(param, 'ca_attr')
+                param.ca_attr.remove_torch_payload()
+
+            del self.initialized_param_list

     def _post_init_method(self, module):
         r"""The function to call at the end of the constructor of each nn.Module.
@@ -121,7 +127,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(torch.half).to(target_device)

-            param.ca_attr = ShardedParamV2(param)
+            param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            self.initialized_param_list.append(param)
             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
             if param.ca_attr.grad and self.shard_grad:
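The new rm_torch_payload_on_the_fly flag decides whether each parameter's original torch payload is dropped immediately when it is wrapped, or only once the context exits (the default, handled by _post_context_exec above). A usage sketch, mirroring the updated test further below; it assumes colossalai.launch(...) has already set up the process group, as run_dist() does in that test:

```python
import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from tests.components_to_test.registry import non_distributed_component_funcs

# Sketch only: pick any registered test component and build its model inside the context.
get_components_func = next(iter(non_distributed_component_funcs))
model_builder, *_ = get_components_func()

with ZeroInitContext(convert_fp16=True,
                     convert_cuda=True,
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False):   # payloads are removed in _post_context_exec instead
    model = model_builder(checkpoint=True)

for param in model.parameters():
    assert hasattr(param, 'ca_attr')                    # every parameter now carries a ShardedParamV2
    assert param.ca_attr.data.dtype == torch.half       # same checks as in test_init_context.py below
    assert param.ca_attr._data_sharded_tensor.is_sharded
```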
colossalai/zero/shard_utils/base_shard_strategy.py

@@ -7,6 +7,11 @@ from typing import List, Optional
 class BaseShardStrategy(ABC):

     def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
+        Args:
+            process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
+        """
         self.process_group = process_group
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
@@ -14,14 +19,8 @@ class BaseShardStrategy(ABC):
     @abstractmethod
     def shard(self, tensor_list: List[ShardedTensor]):
-        r"""
-        sharded the memory of tensor on multiple processes.
-        """
         pass

     @abstractmethod
     def gather(self, tensor_list: List[ShardedTensor]):
-        r"""
-        duplicate tensor payload on each processes.
-        """
         pass
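The base class now resolves world_size and local_rank from the process group, so a concrete strategy only has to implement shard and gather. A skeleton subclass sketched from the interface above (the name NoopShardStrategy and the empty bodies are illustrative only, not the real TensorShardStrategy; instantiating it requires an initialized torch.distributed process group, since the base __init__ queries world size and rank):

```python
from typing import List

from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy


class NoopShardStrategy(BaseShardStrategy):
    """Skeleton only: shows the contract; it does not actually split or restore any payload."""

    def shard(self, tensor_list: List):
        # a real strategy splits each ShardedTensor's payload into self.world_size pieces
        # and keeps only the chunk belonging to self.local_rank
        pass

    def gather(self, tensor_list: List):
        # a real strategy all-gathers the chunks over self.process_group
        # and restores the full payload on every rank
        pass
```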
colossalai/zero/sharded_param/sharded_param.py

@@ -10,7 +10,10 @@ from typing import Union, Tuple, Optional
 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self,
+                 param: torch.nn.Parameter,
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 rm_torch_payload=False) -> None:
         self._data_sharded_tensor = ShardedTensor(param.data, process_group)
         if param.requires_grad and param.grad is not None:
             self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
@@ -19,7 +22,16 @@ class ShardedParamV2(object):
             self._grad_sharded_tensor = None

         # make sure the shared param is the only owner of payload
-        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+        # The param.data maybe used to init the other part of the model.
+        # For example: File "resnet.py", line 190, in __init__
+        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        # So we can not empty the .data at this time
+        self.param = param
+        if rm_torch_payload:
+            self.remove_torch_payload()
+
+    def remove_torch_payload(self):
+        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

     @property
     def data(self):
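The comment added above explains why the payload can no longer be emptied unconditionally inside __init__: later construction code (for example kaiming_normal_ in torchvision's resnet) may still initialize the weight in place. A plain-PyTorch illustration of the failure mode the new rm_torch_payload flag avoids (no ColossalAI classes involved):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)

# Emptying the payload too early, as the removed line did:
conv.weight.data = torch.empty([], dtype=conv.weight.dtype, device=conv.weight.device)

try:
    # The kind of call torchvision's resnet __init__ still makes afterwards:
    nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
except ValueError as e:
    # fan-in/fan-out cannot be computed for a 0-dim tensor, so the in-place init fails
    print(f"in-place init failed: {e}")
```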
tests/components_to_test/nested_model.py

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from colossalai.nn import CheckpointModule
 from .utils import DummyDataGenerator
 from .registry import non_distributed_component_funcs
@@ -15,10 +16,10 @@ class SubNet(nn.Module):
         return F.linear(x, weight, self.bias)


-class NestedNet(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
+class NestedNet(CheckpointModule):
+    def __init__(self, checkpoint=False) -> None:
+        super().__init__(checkpoint)
         self.fc1 = nn.Linear(5, 5)
         self.sub_fc = SubNet(5)
         self.fc2 = nn.Linear(5, 2)
@@ -41,9 +42,15 @@ class DummyDataLoader(DummyDataGenerator):
 @non_distributed_component_funcs.register(name='nested_model')
 def get_training_components():
-    model = NestedNet()
+
+    def model_builder(checkpoint):
+        return NestedNet(checkpoint)
+
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
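Each registered component now returns callables instead of instantiated objects, so the test decides when, and with which checkpoint flag, the model and optimizer are created (for instance inside a ZeroInitContext). A self-contained sketch of the same contract in plain PyTorch; the model and names here are illustrative, not part of the repository:

```python
import torch
import torch.nn as nn


def get_training_components():
    """Returns builders, mirroring the updated test-component contract."""

    def model_builder(checkpoint=False):
        # `checkpoint` is accepted so callers can toggle activation checkpointing
        return nn.Sequential(nn.Linear(5, 5), nn.ReLU(), nn.Linear(5, 2))

    def optim_builder(model):
        # the optimizer is built only after the (possibly sharded) model exists
        return torch.optim.Adam(model.parameters(), lr=0.001)

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, None, None, optim_builder, criterion  # loaders elided in this sketch


model_builder, _, _, optim_builder, criterion = get_training_components()
model = model_builder(checkpoint=False)
optimizer = optim_builder(model)
loss = criterion(model(torch.randn(4, 5)), torch.tensor([0, 1, 0, 1]))
loss.backward()
optimizer.step()
```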
tests/components_to_test/repeated_computed_layer.py

@@ -36,9 +36,15 @@ class DummyDataLoader(DummyDataGenerator):
 @non_distributed_component_funcs.register(name='repeated_computed_layers')
 def get_training_components():
-    model = NetWithRepeatedlyComputedLayers(checkpoint=True)
+
+    def model_builder(checkpoint=True):
+        return NetWithRepeatedlyComputedLayers(checkpoint)
+
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
tests/components_to_test/resnet.py

@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):
 @non_distributed_component_funcs.register(name='resnet18')
 def get_resnet_training_components():
-    model = resnet18(num_classes=10)
+
+    def model_builder(checkpoint=False):
+        return resnet18(num_classes=10)
+
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
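The resnet18 component follows the same builder contract; its checkpoint argument exists only for interface uniformity and is ignored by resnet18 itself. A short consumption sketch (it assumes the CIFAR-10 data used by get_cifar10_dataloader is available locally):

```python
# Sketch: consuming the registered resnet18 component directly.
model_builder, trainloader, testloader, optim_builder, criterion = get_resnet_training_components()

model = model_builder(checkpoint=False)    # checkpoint flag is accepted but unused by resnet18
optimizer = optim_builder(model)           # created after the model, as the builder contract intends

images, labels = next(iter(trainloader))
loss = criterion(model(images), labels)
```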
tests/test_engine/test_engine.py

@@ -16,10 +16,11 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
 def run_train():
     for get_components_func in non_distributed_component_funcs:
-        model, train_dataloader, _, optimizer, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                optimizer=optimizer,
+                                                                optimizer=optimizer_builder(model),
                                                                 criterion=criterion,
                                                                 train_dataloader=train_dataloader)
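colossalai.initialize wraps the freshly built model and optimizer into an engine. The rest of run_train is outside this hunk; the sketch below shows how such an engine is typically driven for a single smoke-test step, and is an assumed continuation rather than code from this commit:

```python
def run_one_step(engine, train_dataloader):
    """Drive the engine returned by colossalai.initialize for one training step (sketch)."""
    engine.train()
    img, label = next(iter(train_dataloader))
    img, label = img.cuda(), label.cuda()
    engine.zero_grad()
    output = engine(img)
    loss = engine.criterion(output, label)
    engine.backward(loss)
    engine.step()
    return loss
```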
tests/test_zero_data_parallel/test_init_context.py

@@ -9,22 +9,27 @@ import torch
 import torch.multiprocessing as mp

 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from colossalai.zero.init_ctx import ZeroInitContext
-from common import CONFIG, Net
+from common import CONFIG
 from colossalai.utils import free_port
+from tests.components_to_test.registry import non_distributed_component_funcs


 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
-        model = Net(checkpoint=True)
-
-    for param in model.parameters():
-        assert hasattr(param, 'ca_attr')
-        assert param.ca_attr.data.dtype == torch.half
-        assert param.ca_attr._data_sharded_tensor.is_sharded
-        assert param.ca_attr.data.device.type == 'cuda'
+    for get_components_func in non_distributed_component_funcs:
+        # Note Net(checkpoint=True).cuda() moving to cuda is useless
+        model_builder, _, _, _, _ = get_components_func()
+        with ZeroInitContext(convert_fp16=True, convert_cuda=True,
+                             shard_strategy=TensorShardStrategy(),
+                             shard_param=True):
+            model = model_builder(checkpoint=True)
+
+        for param in model.parameters():
+            assert hasattr(param, 'ca_attr')
+            assert param.ca_attr.data.dtype == torch.half
+            assert param.ca_attr._data_sharded_tensor.is_sharded
+            assert param.ca_attr.data.device.type == 'cuda'


 @pytest.mark.dist
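The file imports torch.multiprocessing and free_port, but the test driver itself lies outside this hunk. The sketch below shows the usual pattern these ColossalAI dist tests follow; the test name and world size are assumed, not taken from the diff:

```python
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.utils import free_port


@pytest.mark.dist
def test_zero_init_context():
    # Assumed driver: spawn `world_size` ranks, each calling run_dist(rank, world_size, port).
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
```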
tests/test_zero_data_parallel/test_shard_param.py

@@ -46,6 +46,8 @@ def _run_shard_param_v2(rank, world_size, port):
     sparam = ShardedParamV2(param=param, process_group=None)

     allclose(sparam.data, param_ref.data)
+
+    sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
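The assertion exercised here pins down what remove_torch_payload leaves behind: a 0-dim placeholder, so the original storage can be freed while the Parameter object itself survives. A plain-PyTorch restatement of the same check (no ShardedParamV2 involved):

```python
import torch

param = torch.nn.Parameter(torch.randn(1024, 1024))

# What remove_torch_payload() does to the wrapped parameter, restated in plain torch:
param.data = torch.empty([], dtype=param.dtype, device=param.device)

assert param.data.numel() == 1   # the 0-dim tensor keeps exactly one (uninitialized) element
```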