OpenDAS / ColossalAI · Commits · f7e276fa

Commit f7e276fa (unverified), authored Nov 16, 2022 by Jiarui Fang; committed by GitHub, Nov 16, 2022.

[Gemini] add GeminiAdamOptimizer (#1960)

Parent: 7066dfbf

Showing 12 changed files with 66 additions and 44 deletions (+66 / −44):
colossalai/nn/optimizer/gemini_optimizer.py            +15  −0
colossalai/nn/optimizer/hybrid_adam.py                  +9  −6
colossalai/nn/optimizer/zero_optimizer.py               +3  −2
colossalai/nn/parallel/gemini_parallel.py               +2  −1
colossalai/zero/__init__.py                             +1  −1
examples/language/gpt/README.md                         +2  −4
examples/language/gpt/train_gpt_demo.py                 +6  −4
examples/language/opt/run_clm.py                        +1  −1
examples/tutorial/opt/opt/run_clm.py                   +12 −12
tests/test_gemini/update/test_optim.py                  +1  −1
tests/test_gemini/update/test_zerooptim_state_dict.py   +1  −1
tests/test_tensor/test_tp_with_zero.py                 +13 −11
colossalai/nn/optimizer/gemini_optimizer.py  (new file, 0 → 100644)

```python
from typing import Any

import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

__all__ = ['GeminiAdamOptimizer']


class GeminiAdamOptimizer(ZeroOptimizer):

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        optimizer = HybridAdam(model.parameters(), **defaults)
        super().__init__(optimizer, model, **defaults)
```
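The new class collapses the previous two-step construction (build a `HybridAdam` over the module's parameters, then wrap it in a `ZeroOptimizer`) into a single call. A minimal usage sketch, assuming `model` has already been wrapped by `GeminiDDP`/`ZeroDDP` as in the examples further down:

```python
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

# One-step construction; extra kwargs such as initial_scale are forwarded
# via **defaults to both HybridAdam and ZeroOptimizer.
optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

# Equivalent two-step form that this class replaces:
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
# optim = ZeroOptimizer(optimizer, model, initial_scale=2**5)
```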
colossalai/nn/optimizer/hybrid_adam.py

```diff
+from typing import Any, Optional
+
 import torch

-from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
-from typing import Optional
+from colossalai.utils import multi_tensor_applier

 from .nvme_optimizer import NVMeOptimizer
...
@@ -68,14 +70,15 @@ class HybridAdam(NVMeOptimizer):
                  weight_decay=0,
                  adamw_mode=True,
                  nvme_offload_fraction: float = 0.0,
-                 nvme_offload_dir: Optional[str] = None):
+                 nvme_offload_dir: Optional[str] = None,
+                 **defaults: Any):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
         try:
-            import cpu_adam
             import colossal_C
+            import cpu_adam
         except ImportError:
             raise ImportError('Please install colossalai from source code to use HybridAdam')
...
```
colossalai/zero/zero_optimizer.py → colossalai/nn/optimizer/zero_optimizer.py  (renamed)

```diff
 from enum import Enum
-from typing import Dict, Set, Tuple
+from typing import Any, Dict, Set, Tuple

 import torch
 import torch.distributed as dist
...
@@ -55,7 +55,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  backoff_factor: float = 0.5,
                  growth_interval: int = 1000,
                  hysteresis: int = 2,
-                 max_scale: float = 2**32):
+                 max_scale: float = 2**32,
+                 **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
         self.module = module
...
```
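Both signature changes above exist so that `GeminiAdamOptimizer` can forward a single `**defaults` dict to the two classes it composes: each constructor keeps the keywords it understands and silently accepts the rest. A small sketch of the pattern with hypothetical names (`Inner`/`Outer` are illustrative stand-ins, not repo classes):

```python
from typing import Any


class Inner:
    # Stands in for HybridAdam: consumes lr, tolerates initial_scale, etc.
    def __init__(self, params, lr: float = 1e-3, **defaults: Any):
        self.params, self.lr = list(params), lr


class Outer:
    # Stands in for ZeroOptimizer: consumes initial_scale, tolerates lr, etc.
    def __init__(self, inner: Inner, initial_scale: float = 2**32, **defaults: Any):
        self.inner, self.initial_scale = inner, initial_scale


# A single kwargs dict now configures both layers, which is exactly what
# GeminiAdamOptimizer.__init__ does with its **defaults.
kwargs = dict(lr=1e-3, initial_scale=2**5)
opt = Outer(Inner([], **kwargs), **kwargs)
```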
colossalai/nn/parallel/gemini_parallel.py

```diff
@@ -16,8 +16,9 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32) -> None:
         """
-        A torch.Module warpper using ZeRODPP and Genimi.
+        A torch.Module warpper using ZeRO-DP and Genimi.
         ZeRO is for parallel. Gemini is for memory management.
+        WARNING: The class will modify the module inline!
         Example:
             model is initialized under the context of ColoInitContext
...
```
colossalai/zero/__init__.py

```diff
@@ -7,7 +7,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
-from .zero_optimizer import ZeroOptimizer
+from ..nn.optimizer.zero_optimizer import ZeroOptimizer


 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
...
```
examples/language/gpt/README.md

```diff
@@ -3,10 +3,8 @@ This example shows how to use Colossal-AI to run huggingface GPT training in dis
 ## GPT
 We use the GPT2 model from huggingface transformers. The input data is randonly generated.
-The `train_gpt_demo.py` provides three distributed plans, i.e. ColossalAI, PyTorch DDP and ZeRO.
-
-## Our Modifications
-The ColossalAI leverages Tensor Parallel and Gemini.
+The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
+The Colossal-AI leverages Tensor Parallel and Gemini.

 ## Quick Start
 You can launch training by using the following bash script.
...
```
examples/language/gpt/train_gpt_demo.py

```diff
@@ -10,11 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import GPT2Config, GPT2LMHeadModel
...
@@ -222,7 +223,7 @@ def main():
         default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None

         # build GPT model
-        with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
+        with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
             model = gpt2_medium(checkpoint=True)
         pg = default_pg
...
@@ -232,8 +233,9 @@ def main():
         model = gemini_zero_dpp(model, pg, args.placement)

         # build optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
+        optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
+        # optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     elif args.distplan == "ddp":
...
```
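Because `GeminiAdamOptimizer` subclasses `ZeroOptimizer`, the demo's training loop (not shown in this diff) needs no changes: ColossalAI's ZeRO optimizers route the backward pass through the optimizer so loss scaling is applied consistently. A hedged sketch of such a step, with a hypothetical `dataloader` and loss computation:

```python
# Sketch only: the real loop in train_gpt_demo.py is outside this diff.
for input_ids, attn_mask in dataloader:        # hypothetical batch layout
    optimizer.zero_grad()
    loss = criterion(model(input_ids, attn_mask), input_ids)
    optimizer.backward(loss)                   # ZeroOptimizer scales the loss, then backprops
    optimizer.step()
```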
examples/language/opt/run_clm.py

```diff
@@ -43,11 +43,11 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device, get_dataloader
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
...
```
examples/tutorial/opt/opt/run_clm.py

```diff
@@ -30,13 +30,24 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
-import transformers
 from accelerate.utils import set_seed
 from context import barrier_context
 from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
+
+import colossalai
+import transformers
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import get_current_device, get_dataloader
+from colossalai.utils.model.colo_init_context import ColoInitContext
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
...
@@ -50,17 +61,6 @@ from transformers import (
 )
 from transformers.utils.versions import require_version

-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
-
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
...
```
tests/test_gemini/update/test_optim.py

```diff
@@ -12,12 +12,12 @@ from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal
...
```
tests/test_gemini/update/test_zerooptim_state_dict.py

```diff
@@ -9,12 +9,12 @@ import colossalai
 from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
...
```
tests/test_tensor/test_tp_with_zero.py

```diff
@@ -7,16 +7,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.amp import convert_to_apex_amp
-from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.gemini.chunk import search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP, ZeroDDP
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec
...
@@ -96,19 +94,23 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         init_device = torch.device('cpu')
     else:
         init_device = None
-    chunk_manager = ChunkManager(config_dict, init_device=init_device)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
-    optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
+    model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
+    # The same as the following 3 lines
+    # chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    # gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    # model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
+    # The same as the following 2 lines
+    # optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)

     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

-    print(chunk_manager)
     check_param(model, torch_model, pg)

     model.eval()
...
```
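The test passes GeminiDDP's options positionally. Matching them against the signature fragment visible in the gemini_parallel.py hunk above, the keyword form is presumably the following sketch; the `device` and `pin_memory` parameter names are inferred rather than shown in this diff, so treat it as a reading aid, not the definitive signature:

```python
# Keyword form of GeminiDDP(model, init_device, placement_policy, True, False, 32).
model = GeminiDDP(model,
                  device=init_device,              # inferred parameter name
                  placement_policy=placement_policy,
                  pin_memory=True,                 # inferred; mirrors ZeroDDP(..., pin_memory=True)
                  force_outputs_fp32=False,        # shown in the gemini_parallel.py hunk
                  search_range_mb=32)              # shown in the gemini_parallel.py hunk
```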