OpenDAS / ColossalAI · Commit 93582629 (unverified)

Authored Jan 12, 2023 by Haofan Wang; committed by GitHub on Jan 12, 2023.
Fix False warning in initialize.py (#2456)

* Update initialize.py
* pre-commit run check
Parent: 32c46e14

Showing 1 changed file with 14 additions and 15 deletions.
colossalai/initialize.py (+14, -15)
@@ -15,26 +15,25 @@ from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-from colossalai.core import global_context as gpc
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.logging import get_dist_logger
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
-from colossalai.engine import Engine
-from colossalai.gemini.ophooks import BaseOpHook
-from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
-from colossalai.utils.moe import sync_moe_model_param
 from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.engine import Engine
 from colossalai.engine.gradient_accumulation import accumulate_gradient
+from colossalai.engine.schedule import (
+    InterleavedPipelineSchedule,
+    NonPipelineSchedule,
+    PipelineSchedule,
+    get_tensor_shape,
+)
+from colossalai.gemini.ophooks import BaseOpHook
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
+from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
+from colossalai.utils.moe import sync_moe_model_param
 from colossalai.zero import convert_to_zero_v2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
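The hunk above changes no behavior: it is the "pre-commit run check" half of the commit, which reorders the colossalai.* imports alphabetically by module path (isort-style) and wraps the long colossalai.engine.schedule import with a trailing comma. A minimal sketch of that kind of reordering, using isort's public isort.code API on two of the lines from this diff (isort itself is an assumption here; the commit only says a pre-commit run was done):

import isort  # pip install isort

messy = (
    "from colossalai.core import global_context as gpc\n"
    "from colossalai.context.moe_context import MOE_CONTEXT\n"
)

# isort sorts by module path, so colossalai.context.moe_context
# comes before colossalai.core -- the same reordering the diff shows.
print(isort.code(messy))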
@@ -301,9 +300,9 @@ def initialize(model: nn.Module,
         model = model().to(get_current_device())

     # optimizer maybe a optimizer_cls
-    logger.warning("Initializing an non ZeRO model with optimizer class")
     if isinstance(optimizer, Callable):
         optimizer = optimizer(model.parameters())
+        logger.warning("Initializing an non ZeRO model with optimizer class")

     if not use_zero:
         if is_using_sequence():
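This second hunk is the actual fix for the false warning named in the title: previously, initialize() logged "Initializing an non ZeRO model with optimizer class" unconditionally, even when the caller had passed an already-constructed Optimizer instance. After the change, the warning is emitted only inside the isinstance(optimizer, Callable) branch, i.e. only when an optimizer class or factory is really being instantiated. A minimal, self-contained sketch of the corrected control flow (the build_optimizer helper and the "demo" logger are hypothetical stand-ins, not ColossalAI API):

import logging
from typing import Callable, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")  # stand-in for colossalai's get_dist_logger()

def build_optimizer(model: nn.Module, optimizer: Union[Optimizer, Callable]) -> Optimizer:
    # Fixed behavior: the warning lives inside the branch, so it only
    # fires when a class/factory is actually instantiated here.
    if isinstance(optimizer, Callable):
        optimizer = optimizer(model.parameters())
        logger.warning("Initializing an non ZeRO model with optimizer class")
    return optimizer

model = nn.Linear(4, 2)
build_optimizer(model, torch.optim.SGD(model.parameters(), lr=0.1))  # silent: already an instance
build_optimizer(model, lambda ps: torch.optim.SGD(ps, lr=0.1))       # warns: factory was passed

Optimizer instances do not implement __call__, so the isinstance(..., Callable) check cleanly separates the two cases.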