Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
88759e28
Unverified
Commit
88759e28
authored
Apr 19, 2022
by
HELSON
Committed by
GitHub
Apr 19, 2022
Browse files
[zero] add ZeroTensorShardStrategy (#793)
parent
681addb5
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
214 additions
and
11 deletions
+214
-11
colossalai/kernel/cuda_native/csrc/zero_comm.cpp
colossalai/kernel/cuda_native/csrc/zero_comm.cpp
+109
-0
colossalai/zero/comm/__init__.py
colossalai/zero/comm/__init__.py
+1
-0
colossalai/zero/comm/zero_comm.py
colossalai/zero/comm/zero_comm.py
+46
-0
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+2
-0
colossalai/zero/shard_utils/__init__.py
colossalai/zero/shard_utils/__init__.py
+2
-1
colossalai/zero/shard_utils/zero_tensor_shard_strategy.py
colossalai/zero/shard_utils/zero_tensor_shard_strategy.py
+38
-0
setup.py
setup.py
+6
-0
tests/test_zero/test_found_inf.py
tests/test_zero/test_found_inf.py
+2
-2
tests/test_zero/test_init_context.py
tests/test_zero/test_init_context.py
+2
-2
tests/test_zero/test_mem_collector.py
tests/test_zero/test_mem_collector.py
+2
-2
tests/test_zero/test_shard_model_v2.py
tests/test_zero/test_shard_model_v2.py
+2
-2
tests/test_zero/test_state_dict.py
tests/test_zero/test_state_dict.py
+2
-2
No files found.
colossalai/kernel/cuda_native/csrc/zero_comm.cpp
0 → 100644
View file @
88759e28
#include <cuda_runtime.h>
#include <nccl.h>
#include <torch/extension.h>
// Fails a TORCH_CHECK with "<expr> must be a CUDA tensor" unless the device of
// `x` reports is_cuda(). The argument is parenthesized so that any expression
// (e.g. `*ptr`) can be passed, not only a plain identifier.
#define CHECK_CUDA(x) \
  TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
// Fails a TORCH_CHECK with "<expr> must be contiguous" unless
// `x.is_contiguous()` holds. The argument is parenthesized so that any
// expression (e.g. `*ptr`) can be passed, not only a plain identifier.
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs both tensor precondition checks (CUDA device, contiguous layout).
// Wrapped in do { } while (0) so the macro behaves as ONE statement: with the
// previous two-statement expansion, `if (c) CHECK_INPUT(t);` guarded only the
// first check and always executed the second.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
// Evaluates the CUDA runtime call `cmd` exactly once. If the result is not
// cudaSuccess, reports file, line and the CUDA error string on stderr and
// terminates the process with EXIT_FAILURE.
//
// Hygiene notes:
//  * `cmd` is parenthesized so any expression can be passed.
//  * The temporary has a macro-specific name; the former name `e` shadowed a
//    caller variable `e` used inside `cmd`, turning the initializer into a
//    read of the (uninitialized) temporary itself.
#define CUDACHECK(cmd)                                             \
  do {                                                             \
    const cudaError_t colo_cuda_status_ = (cmd);                   \
    if (colo_cuda_status_ != cudaSuccess) {                        \
      fprintf(stderr, "Failed: Cuda error %s:%d '%s'\n", __FILE__, \
              __LINE__, cudaGetErrorString(colo_cuda_status_));    \
      exit(EXIT_FAILURE);                                          \
    }                                                              \
  } while (0)
// Evaluates the NCCL call `cmd` exactly once. If the result is not
// ncclSuccess, reports file, line and the NCCL error string on stderr and
// terminates the process with EXIT_FAILURE.
//
// Hygiene notes:
//  * `cmd` is parenthesized so any expression can be passed.
//  * The temporary has a macro-specific name; the former name `r` shadowed a
//    caller variable `r` used inside `cmd`, turning the initializer into a
//    read of the (uninitialized) temporary itself.
#define NCCLCHECK(cmd)                                              \
  do {                                                              \
    const ncclResult_t colo_nccl_status_ = (cmd);                   \
    if (colo_nccl_status_ != ncclSuccess) {                         \
      fprintf(stderr, "Failed, NCCL error %s:%d '%s'\n", __FILE__,  \
              __LINE__, ncclGetErrorString(colo_nccl_status_));     \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)
class
ZeroCommMgr
{
public:
cudaStream_t
cuda_stream
;
ncclComm_t
nccl_comm
;
ZeroCommMgr
(
const
ncclComm_t
&
comm_
)
{
CUDACHECK
(
cudaStreamCreate
(
&
cuda_stream
));
nccl_comm
=
comm_
;
}
};
// Process-wide communication manager. Stays null until create_zero_comm()
// assigns it; colo_all_gather_impl() reads the communicator and stream from it.
ZeroCommMgr *GMGR = nullptr;
#ifdef USE_C10D_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
// Thin subclass of c10d::ProcessGroupNCCL whose only purpose is to call
// getRank()/getSize()/broadcastUniqueNCCLID() from inside the class hierarchy.
// create_zero_comm() reinterprets an existing ProcessGroupNCCL as this type,
// so the class must not add data members or virtual functions.
class HackNCCLGroup : public c10d::ProcessGroupNCCL {
 public:
  // Creates a new NCCL communicator spanning every rank of this process group.
  //
  // Rank 0 generates the unique id, the id is distributed through
  // broadcastUniqueNCCLID(), and every rank then joins with ncclCommInitRank.
  //
  // @param dev  Currently unused.
  //             NOTE(review): ncclCommInitRank presumably binds to the calling
  //             thread's current CUDA device rather than `dev` -- confirm that
  //             callers have already selected the intended device.
  // @return     The newly initialized communicator.
  ncclComm_t getcomm(at::Device dev) {
    (void)dev;

    ncclUniqueId ncclID;
    const int rank = getRank();
    if (rank == 0) {
      // Previously the result was ignored; a failure here would have broadcast
      // an uninitialized id to all other ranks.
      NCCLCHECK(ncclGetUniqueId(&ncclID));
    }
    // NOTE(review): OpType::SEND together with a fixed key string looks like a
    // way to select a named store key for the exchange -- verify against the
    // ProcessGroupNCCL version this extension is built with.
    broadcastUniqueNCCLID(&ncclID, c10d::OpType::SEND, "colossal_zero_comm",
                          rank);

    ncclComm_t comm;
    NCCLCHECK(ncclCommInitRank(&comm, getSize(), ncclID, rank));
    return comm;
  }
};
// Builds the global ZeRO communication manager (GMGR) from an existing c10d
// NCCL process group.
//
// @param pg   Process group whose ranks form the new communicator.
// @param dev  Forwarded to HackNCCLGroup::getcomm().
// @return     Always 0; failures are reported through TORCH_CHECK or the
//             NCCLCHECK/CUDACHECK macros.
int create_zero_comm(c10d::ProcessGroupNCCL &pg, at::Device dev) {
  // A second call would overwrite GMGR and leak the previous communicator and
  // stream, so reject it explicitly.
  TORCH_CHECK(GMGR == nullptr,
              "Colossal Zero communication environment has already been "
              "created");

  // Deliberate hack: view the process group as HackNCCLGroup to reach members
  // that are only callable from within the class hierarchy.
  auto *hack_group = reinterpret_cast<HackNCCLGroup *>(&pg);
  ncclComm_t comm = hack_group->getcomm(dev);

  // TORCH_CHECK instead of assert(): assert() disappears when NDEBUG is
  // defined, which would leave release builds without any validation.
  TORCH_CHECK(comm != nullptr,
              "Failed to create the NCCL communicator for Colossal Zero");

  GMGR = new ZeroCommMgr(comm);
  return 0;
}
#endif
template
<
typename
scalar_t
>
void
colo_all_gather_impl
(
scalar_t
*
recvbuff
,
int
rank
,
int
sendcount
,
ncclDataType_t
data_type
)
{
scalar_t
*
sendbuff
=
recvbuff
+
(
rank
*
sendcount
);
NCCLCHECK
(
ncclAllGather
(
sendbuff
,
recvbuff
,
sendcount
,
data_type
,
GMGR
->
nccl_comm
,
GMGR
->
cuda_stream
));
CUDACHECK
(
cudaStreamSynchronize
(
GMGR
->
cuda_stream
));
}
// Python-facing entry point: all-gathers `input_tensor` in place.
//
// The tensor is treated as `world_size` equally sized slices; slice `rank`
// must already contain this process' data, and on return every slice holds
// the data of the corresponding rank.
//
// @param input_tensor  Contiguous CUDA tensor of floating or half type whose
//                      element count is a multiple of `world_size`.
// @param rank          This process' slice index, in [0, world_size).
// @param world_size    Number of participating ranks, > 0.
// @return              Always 0; invalid arguments fail a TORCH_CHECK.
int colo_all_gather(torch::Tensor &input_tensor, int rank, int world_size) {
  CHECK_INPUT(input_tensor);

  // These were either unchecked or guarded by assert(), which is compiled out
  // when NDEBUG is defined.
  TORCH_CHECK(GMGR != nullptr,
              "Colossal Zero communication environment is not initialized; "
              "call create_comm first");
  TORCH_CHECK(world_size > 0, "world_size must be positive");
  TORCH_CHECK(rank >= 0 && rank < world_size,
              "rank must be in the range [0, world_size)");

  const auto total_size = input_tensor.numel();
  TORCH_CHECK(total_size % world_size == 0,
              "input_tensor size must be divisible by world_size");
  const auto sendcount = total_size / world_size;

  // Map the tensor dtype to the NCCL element type. The dispatch macro below
  // also instantiates the double case, which used to be sent as ncclFloat
  // (i.e. only half of the bytes were gathered).
  ncclDataType_t nccl_type = ncclFloat;
  switch (input_tensor.scalar_type()) {
    case at::ScalarType::Half:
      nccl_type = ncclHalf;
      break;
    case at::ScalarType::Double:
      nccl_type = ncclDouble;
      break;
    default:
      // Float; any other dtype is not instantiated by the dispatch macro.
      break;
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input_tensor.scalar_type(), "colo_all_gather", ([&] {
        colo_all_gather_impl<scalar_t>(input_tensor.data_ptr<scalar_t>(), rank,
                                       sendcount, nccl_type);
      }));
  return 0;
}
// Python bindings exposed by this extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef USE_C10D_NCCL
  // Only registered when built with USE_C10D_NCCL, because create_zero_comm
  // is only compiled in that configuration.
  m.def("create_comm", &create_zero_comm,
        "Create the communication environment for Colossal Zero");
#endif
  // Always registered; it relies on GMGR having been set up by create_comm.
  m.def("inplace_all_gather", &colo_all_gather,
        "All gather operation used in Colossal Zero");
}
colossalai/zero/comm/__init__.py
0 → 100644
View file @
88759e28
from
.zero_comm
import
ZeroDist
colossalai/zero/comm/zero_comm.py
0 → 100644
View file @
88759e28
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
colossalai.context.singleton_meta
import
SingletonMeta
from
colossalai.utils
import
get_current_device
from
typing
import
Optional
# Probe for the compiled `colossal_zero_comm` extension. ZERO_USE_NCCL records
# whether the import succeeded; a hint is printed when it did not.
try:
    import colossal_zero_comm
except ImportError:
    ZERO_USE_NCCL = False
    print("Please pip reinstall Colossalai.")
else:
    ZERO_USE_NCCL = True
class ZeroCommWorld(metaclass=SingletonMeta):
    """Zero communicator, used for communications in zero parallel.

    Wraps the `colossal_zero_comm` extension: `zero_comm_init` creates the
    extension-side communicator for one process group, and `zero_all_gather`
    runs the extension's in-place all-gather on that group.
    """

    def __init__(self):
        super().__init__()
        # Process group the communicator was created for; None until
        # `zero_comm_init` succeeds.
        self.zero_pg: Optional[ProcessGroup] = None

    @property
    def is_initialized(self):
        """Whether `zero_comm_init` has already bound a process group."""
        return self.zero_pg is not None

    def zero_comm_init(self, comm_group: ProcessGroup):
        """Create the extension-side communicator for `comm_group`.

        Does nothing when the extension is unavailable. Calling it again with
        the same group is a no-op; a different group fails the assertion.

        Args:
            comm_group: Process group used for zero communication.
        """
        if not ZERO_USE_NCCL:
            return

        if self.is_initialized:
            assert self.zero_pg == comm_group, "Cant not initialize zero group twice"
            return

        self.zero_pg = comm_group
        colossal_zero_comm.create_comm(self.zero_pg, get_current_device())

    def zero_all_gather(self, input_tensor: torch.Tensor):
        """All-gather `input_tensor` in place across the zero process group.

        Args:
            input_tensor: Buffer passed to the extension's `inplace_all_gather`
                together with this process' rank and the group size.

        Raises:
            AssertionError: If the extension is unavailable or the
                communicator has not been initialized.
        """
        # Without the extension `zero_comm_init` returns silently, so report the
        # real cause here instead of the misleading "initialize first" message.
        assert ZERO_USE_NCCL, \
            "The colossal_zero_comm extension is not available. Please pip reinstall Colossalai."
        assert self.zero_pg is not None, "Please initialize zero communication world first"
        rank = dist.get_rank(self.zero_pg)
        world_size = self.zero_pg.size()
        colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)
# Module-level communicator instance; this is the name re-exported by
# `colossalai.zero.comm` and used by callers.
ZeroDist = ZeroCommWorld()
colossalai/zero/init_ctx/init_context.py
View file @
88759e28
...
@@ -12,6 +12,7 @@ from colossalai.logging import get_dist_logger
...
@@ -12,6 +12,7 @@ from colossalai.logging import get_dist_logger
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.comm
import
ZeroDist
from
contextlib
import
AbstractContextManager
from
contextlib
import
AbstractContextManager
...
@@ -191,6 +192,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -191,6 +192,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
The Callback function when entering the context
The Callback function when entering the context
"""
"""
self
.
logger
=
get_dist_logger
(
"ZeroInitContext"
)
self
.
logger
=
get_dist_logger
(
"ZeroInitContext"
)
ZeroDist
.
zero_comm_init
(
self
.
dp_process_group
)
# initialize zero communication world
# substitute fan-in and fan-out calculation
# substitute fan-in and fan-out calculation
self
.
nn_fanin_fanout
=
nn
.
init
.
_calculate_fan_in_and_fan_out
self
.
nn_fanin_fanout
=
nn
.
init
.
_calculate_fan_in_and_fan_out
...
...
colossalai/zero/shard_utils/__init__.py
View file @
88759e28
from
.base_shard_strategy
import
BaseShardStrategy
from
.base_shard_strategy
import
BaseShardStrategy
from
.bucket_tensor_shard_strategy
import
BucketTensorShardStrategy
from
.bucket_tensor_shard_strategy
import
BucketTensorShardStrategy
from
.tensor_shard_strategy
import
TensorShardStrategy
from
.tensor_shard_strategy
import
TensorShardStrategy
from
.zero_tensor_shard_strategy
import
ZeroTensorShardStrategy
__all__
=
[
'BaseShardStrategy'
,
'TensorShardStrategy'
,
'BucketTensorShardStrategy'
]
__all__
=
[
'BaseShardStrategy'
,
'TensorShardStrategy'
,
'BucketTensorShardStrategy'
,
'ZeroTensorShardStrategy'
]
colossalai/zero/shard_utils/zero_tensor_shard_strategy.py
0 → 100644
View file @
88759e28
from
typing
import
Optional
import
torch
import
torch.distributed
as
dist
from
colossalai.utils
import
get_current_device
from
colossalai.zero.sharded_param.tensor_utils
import
colo_model_data_tensor_move_inline
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.comm
import
ZeroDist
from
.tensor_shard_strategy
import
TensorShardStrategy
class ZeroTensorShardStrategy(TensorShardStrategy):
    """Shard strategy with the same sharding scheme as `TensorShardStrategy`.

    The difference is the all-gather: it runs in place through `ZeroDist`, so
    no additional buffer is needed, whereas `torch.distributed.all_gather`
    creates one. This can lower peak memory in zero-offload.

    This strategy is tightly coupled with zero: its communication group cannot
    be changed, and the model must be created with ZeroContext.
    """

    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
        # Nothing to gather for a tensor that is already whole.
        if not t.is_sharded:
            return

        # Remember where the tensor lives so it can be moved back afterwards.
        origin_device = t.device
        shard = t.payload
        shard_numel = shard.numel()
        group_size = dist.get_world_size(process_group)
        local_rank = dist.get_rank(process_group)

        # One flat buffer on the current device with a slot per rank; the local
        # shard goes into this rank's slot before the in-place gather.
        gather_buffer = torch.empty(shard_numel * group_size, dtype=shard.dtype, device=get_current_device())
        slots = list(torch.chunk(gather_buffer, chunks=group_size, dim=0))
        slots[local_rank].copy_(shard)

        # notice: process_group is useless here
        ZeroDist.zero_all_gather(gather_buffer)

        # Keep only the first `origin_numel` elements and restore the shape.
        full_payload = torch.narrow(gather_buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.reset_payload(full_payload)
        colo_model_data_tensor_move_inline(t, origin_device)
        t.is_sharded = False
setup.py
View file @
88759e28
...
@@ -134,6 +134,12 @@ if build_cuda_ext:
...
@@ -134,6 +134,12 @@ if build_cuda_ext:
'nvcc'
:
append_nvcc_threads
([
'-O3'
,
'--use_fast_math'
]
+
version_dependent_macros
+
extra_cuda_flags
)
'nvcc'
:
append_nvcc_threads
([
'-O3'
,
'--use_fast_math'
]
+
version_dependent_macros
+
extra_cuda_flags
)
})
})
ext_modules
.
append
(
cuda_ext_helper
(
name
=
'colossal_zero_comm'
,
sources
=
[
'zero_comm.cpp'
],
extra_cuda_flags
=
[
'-DUSE_C10D_NCCL'
],
extra_cxx_flags
=
[
'-DUSE_C10D_NCCL'
]))
ext_modules
.
append
(
ext_modules
.
append
(
cuda_ext_helper
(
'colossal_C'
,
[
cuda_ext_helper
(
'colossal_C'
,
[
'colossal_C_frontend.cpp'
,
'multi_tensor_sgd_kernel.cu'
,
'multi_tensor_scale_kernel.cu'
,
'colossal_C_frontend.cpp'
,
'multi_tensor_sgd_kernel.cu'
,
'multi_tensor_scale_kernel.cu'
,
...
...
tests/test_zero/test_found_inf.py
View file @
88759e28
...
@@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
...
@@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
Bucket
TensorShardStrategy
from
colossalai.zero.shard_utils
import
Zero
TensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
...
@@ -20,7 +20,7 @@ from common import CONFIG
...
@@ -20,7 +20,7 @@ from common import CONFIG
@
parameterize
(
"cpu_offload"
,
[
True
,
False
])
@
parameterize
(
"cpu_offload"
,
[
True
,
False
])
@
parameterize
(
"shard_strategy_class"
,
[
Bucket
TensorShardStrategy
])
@
parameterize
(
"shard_strategy_class"
,
[
Zero
TensorShardStrategy
])
@
parameterize
(
"gpu_margin_mem_ratio"
,
[
0.0
,
0.7
])
@
parameterize
(
"gpu_margin_mem_ratio"
,
[
0.0
,
0.7
])
def
_run_test_found_inf
(
cpu_offload
,
shard_strategy_class
,
gpu_margin_mem_ratio
):
def
_run_test_found_inf
(
cpu_offload
,
shard_strategy_class
,
gpu_margin_mem_ratio
):
test_models
=
[
'repeated_computed_layers'
]
test_models
=
[
'repeated_computed_layers'
]
...
...
tests/test_zero/test_init_context.py
View file @
88759e28
...
@@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
...
@@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
colo_model_mem_usage
colo_model_mem_usage
from
colossalai.utils.memory
import
colo_device_memory_used
from
colossalai.utils.memory
import
colo_device_memory_used
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
,
ZeroTensorShardStrategy
)
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
common
import
CONFIG
from
common
import
CONFIG
@
parameterize
(
"init_device_type"
,
[
'cpu'
,
'cuda'
])
@
parameterize
(
"init_device_type"
,
[
'cpu'
,
'cuda'
])
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
,
ZeroTensorShardStrategy
])
def
run_model_test
(
init_device_type
,
shard_strategy_class
):
def
run_model_test
(
init_device_type
,
shard_strategy_class
):
logger
=
get_dist_logger
(
"test_zero_init"
)
logger
=
get_dist_logger
(
"test_zero_init"
)
...
...
tests/test_zero/test_mem_collector.py
View file @
88759e28
...
@@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
...
@@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.shard_utils
import
Bucket
TensorShardStrategy
from
colossalai.zero.shard_utils
import
Zero
TensorShardStrategy
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
functools
import
partial
from
functools
import
partial
...
@@ -35,7 +35,7 @@ def run_mem_collector_testing():
...
@@ -35,7 +35,7 @@ def run_mem_collector_testing():
fraction
=
(
50
*
1024
**
2
)
/
cuda_capacity
fraction
=
(
50
*
1024
**
2
)
/
cuda_capacity
# limit max memory to 50MB
# limit max memory to 50MB
colo_set_process_memory_fraction
(
fraction
)
colo_set_process_memory_fraction
(
fraction
)
shard_strategy
=
Bucket
TensorShardStrategy
()
shard_strategy
=
Zero
TensorShardStrategy
()
with
ZeroInitContext
(
target_device
=
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
with
ZeroInitContext
(
target_device
=
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
model
=
MyTestModel
()
model
=
MyTestModel
()
...
...
tests/test_zero/test_shard_model_v2.py
View file @
88759e28
...
@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
...
@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
Zero
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
...
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
...
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
@
parameterize
(
"enable_autocast"
,
[
True
])
@
parameterize
(
"enable_autocast"
,
[
True
])
@
parameterize
(
"shard_strategy_class"
,
[
BucketTensorShardStrategy
])
@
parameterize
(
"shard_strategy_class"
,
[
ZeroTensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_model_test
(
enable_autocast
,
shard_strategy_class
):
def
run_model_test
(
enable_autocast
,
shard_strategy_class
):
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
,
'no_leaf_module'
]
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
,
'no_leaf_module'
]
shard_strategy
=
shard_strategy_class
()
shard_strategy
=
shard_strategy_class
()
...
...
tests/test_zero/test_state_dict.py
View file @
88759e28
...
@@ -11,7 +11,7 @@ import torch.multiprocessing as mp
...
@@ -11,7 +11,7 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
,
ZeroTensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
@@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
...
@@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from
common
import
CONFIG
from
common
import
CONFIG
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
,
ZeroTensorShardStrategy
])
def
run_zero_state_dict
(
shard_strategy_class
):
def
run_zero_state_dict
(
shard_strategy_class
):
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
shard_strategy
=
shard_strategy_class
()
shard_strategy
=
shard_strategy_class
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment