Unverified Commit 01c7fb04 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[DeepSpeed] simplify init (#10762)

parent 0486ccdd
...@@ -22,7 +22,6 @@ import os ...@@ -22,7 +22,6 @@ import os
import re import re
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from .utils import logging from .utils import logging
from .utils.versions import require_version from .utils.versions import require_version
...@@ -430,16 +429,12 @@ def init_deepspeed(trainer, num_training_steps): ...@@ -430,16 +429,12 @@ def init_deepspeed(trainer, num_training_steps):
"enabled": True, "enabled": True,
} }
# for clarity extract the specific cl args that are being passed to deepspeed
ds_args = dict(local_rank=args.local_rank)
# keep for quick debug: # keep for quick debug:
# from pprint import pprint; pprint(config) # from pprint import pprint; pprint(config)
# init that takes part of the config via `args`, and the bulk of it via `config_params` # init that takes part of the config via `args`, and the bulk of it via `config_params`
model_parameters = filter(lambda p: p.requires_grad, model.parameters()) model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model, optimizer, _, lr_scheduler = deepspeed.initialize( model, optimizer, _, lr_scheduler = deepspeed.initialize(
args=SimpleNamespace(**ds_args), # expects an obj
model=model, model=model,
model_parameters=model_parameters, model_parameters=model_parameters,
config_params=config, config_params=config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment